# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
============================================
"""

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases like image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and end-to-end
# examples, aiming to enable and accelerate research in
# multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, i.e. visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder to combine
# the two embeddings. It is pretrained using contrastive, image-text matching, and
# text, image, and multimodal masking losses.


######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the ``bert`` tokenizer from Hugging Face for this
# tutorial, so you need to install ``datasets`` and ``transformers`` in addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
#    .. code-block::
#
#       !pip install torchmultimodal-nightly
#       !pip install datasets
#       !pip install transformers
#

######################################################################
# Steps
# -----
#
# 1. Download the Hugging Face dataset to a directory on your computer by running the following command:
#
#    .. code-block::
#
#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#       tar xf vocab.tar.gz
#
#    .. note::
#       If you are running this tutorial in Google Colab, run these commands
#       in a new cell and prepend them with an exclamation mark (!).
#
#
# 2. For this tutorial, we treat VQA as a classification task where
#    the inputs are images and questions (text) and the output is an answer class,
#    so we need to download the vocab file with answer classes and create the
#    answer-to-label mapping.
#
#    We also load the `textvqa
#    dataset <https://arxiv.org/pdf/1904.08920.pdf>`__ containing 34602 training samples
#    (images, questions, and answers) from Hugging Face.
#
#    We see there are 3997 answer classes, including a class representing
#    unknown answers.
#

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
import numpy as np
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()
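
######################################################################
# As a quick sanity check, we can map the raw answer strings for this sample
# to label indices with the ``answer_to_idx`` dictionary built above. The
# snippet below is a minimal sketch: answers missing from the vocab fall back
# to index 0, mirroring the ``answer_to_idx.get(..., 0)`` lookup used in the
# transform in the next step.
#

from collections import Counter

sample_answers = dataset["train"][idx]["answers"]
# Each TextVQA sample carries several human-annotated answers to the same question.
print("Raw answers: ", sample_answers)
print("Label indices: ", [answer_to_idx.get(ans, 0) for ans in sample_answers])
# The most frequent annotation is what the transform below will use as the label.
print("Most common answer: ", Counter(sample_answers).most_common(1)[0][0])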

######################################################################
# 3. Next, we write the transform function to convert the image and text into
#    Tensors consumable by our model:
#
#    - For images, we use the transforms from torchvision to convert to a
#      Tensor and resize to a uniform size.
#    - For text, we tokenize (and pad) the question using the ``BertTokenizer``
#      from Hugging Face.
#    - For answers (i.e. labels), we take the most frequently occurring answer
#      as the label to train with:
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
    batch.update(tokenized)

    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])
    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)


######################################################################
# 4. Finally, we import ``flava_model_for_classification`` from
#    ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
#    includes a classification head.
#
#    The model forward function passes the image through the visual encoder
#    and the question through the text encoder. The image and question
#    embeddings are then passed through the multimodal encoder. The final
#    embedding corresponding to the CLS token is passed through an MLP head,
#    which gives the probability distribution over the possible
#    answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))


######################################################################
# 5. We put together the dataset and model in a toy training loop to
#    demonstrate how to train the model for 3 iterations:
#

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        if idx >= MAX_STEPS - 1:
            break
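
######################################################################
# As a quick qualitative check, we can switch the model to eval mode and decode
# the prediction for one batch back into an answer string using ``vocab``. This
# is a minimal sketch; it assumes the classification output also exposes a
# ``logits`` field next to ``loss``, so verify the field name against the
# ``torchmultimodal`` version you have installed.
#

model.eval()
with torch.no_grad():
    eval_batch = next(iter(train_dataloader))
    out = model(text=eval_batch["input_ids"], image=eval_batch["image"], labels=eval_batch["answers"])
    # ``out.logits`` is assumed to hold the per-answer-class scores.
    pred_idx = out.logits.argmax(dim=-1)[0].item()
print("Predicted answer: ", vocab[pred_idx].strip("\n"))
print("Reference answer: ", vocab[eval_batch["answers"][0].item()].strip("\n"))
model.train()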

######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics of finetuning on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library like
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# which is a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# which is a multitask model spanning image, video, and 3D classification.
#