# -*- coding: utf-8 -*-
"""
TorchMultimodal Tutorial: Finetuning FLAVA
==========================================
"""

######################################################################
# Multimodal AI has recently become very popular owing to its ubiquitous
# nature, from use cases such as image captioning and visual search to more
# recent applications like image generation from text. **TorchMultimodal
# is a library powered by PyTorch consisting of building blocks and end-to-end
# examples, aiming to enable and accelerate research in multimodality**.
#
# In this tutorial, we will demonstrate how to use a **pretrained SoTA
# model called** `FLAVA <https://arxiv.org/pdf/2112.04482.pdf>`__ **from
# the TorchMultimodal library to finetune on a multimodal task, namely visual
# question answering** (VQA). The model consists of two unimodal transformer-based
# encoders for text and image and a multimodal encoder that combines
# the two embeddings. It is pretrained with contrastive, image-text matching, and
# text, image, and multimodal masking losses.


######################################################################
# Installation
# -----------------
# We will use the TextVQA dataset and the ``BertTokenizer`` from Hugging Face for
# this tutorial, so you need to install ``datasets`` and ``transformers`` in
# addition to TorchMultimodal.
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages by
#    creating a new cell and running the following commands:
#
#    .. code-block::
#
#       !pip install torchmultimodal-nightly
#       !pip install datasets
#       !pip install transformers
#

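######################################################################
# You can quickly verify the installation by importing the three packages in a
# Python session; if the following imports succeed, the dependencies are in place:
#
# .. code-block::
#
#    import torchmultimodal
#    import datasets
#    import transformers
#
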
######################################################################
# Steps
# -----
#
# 1. Download the vocabulary file containing the answer classes to a directory on
#    your computer by running the following command:
#
#    .. code-block::
#
#       wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
#       tar xf vocab.tar.gz
#
#    .. note::
#
#       If you are running this tutorial in Google Colab, run these commands
#       in a new cell and prepend them with an exclamation mark (!).
#
#
# 2. For this tutorial, we treat VQA as a classification task where the inputs
#    are an image and a question (text) and the output is an answer class, so we
#    need to read the vocab file with the answer classes and create the
#    answer-to-label mapping.
#
#    We also load the `textvqa
#    dataset <https://arxiv.org/pdf/1904.08920.pdf>`__, containing 34602 training
#    samples (images, questions, and answers), from Hugging Face.
#
# We see that there are 3997 answer classes, including a class representing
# unknown answers.
#

# Build the answer-to-index mapping from the downloaded vocab file.
with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
    vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
    answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

# Load the TextVQA dataset from Hugging Face.
from datasets import load_dataset
dataset = load_dataset("textvqa")

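######################################################################
# Printing the ``DatasetDict`` returned by ``load_dataset`` shows the available
# splits and the number of rows in each, which lets us confirm the 34602
# training samples mentioned above:
#

print(dataset)
print(dataset["train"].num_rows)
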
######################################################################
# Let's display a sample entry from the dataset:
#

import matplotlib.pyplot as plt
import numpy as np

idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: ", dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500, 500)))
plt.imshow(im)
plt.show()


######################################################################
# 3. Next, we write the transform function to convert the image and text into
#    Tensors consumable by our model:
#
#    - For images, we use the transforms from torchvision to convert them to
#      Tensors and resize them to a uniform size.
#    - For text, we tokenize (and pad) the questions using the ``BertTokenizer``
#      from Hugging Face.
#    - For answers (that is, labels), we take the most frequently occurring
#      answer as the label to train with:
#

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
    batch = {}
    # Convert the PIL image to a Tensor and resize it to 224x224.
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
    image = image_transform(input["image"][0].convert("RGB"))
    batch["image"] = [image]

    # Tokenize and pad the question to a fixed length of 512 tokens.
    tokenized = tokenizer(input["question"], return_tensors="pt", padding="max_length", max_length=512)
    batch.update(tokenized)

    # Use the most frequently occurring answer as the label; answers that are
    # not in the vocab fall back to index 0.
    ans_to_count = defaultdict(int)
    for ans in input["answers"][0]:
        ans_to_count[ans] += 1
    max_value = max(ans_to_count, key=ans_to_count.get)
    ans_idx = answer_to_idx.get(max_value, 0)
    batch["answers"] = torch.as_tensor([ans_idx])
    return batch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", padding="max_length", max_length=512)
transform = partial(transform, tokenizer)
dataset.set_transform(transform)


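######################################################################
# As a quick sanity check, we can look at one transformed sample. With
# ``set_transform``, the transform runs on the fly whenever an item is accessed,
# so the entry below should contain the image tensor, the tokenized question,
# and the answer label produced by ``transform`` (the shapes assume the 224x224
# resize and ``max_length=512`` padding used above):
#

sample = dataset["train"][idx]
print(sample.keys())
print(sample["image"][0].shape)   # expected: torch.Size([3, 224, 224])
print(sample["input_ids"].shape)  # expected: torch.Size([1, 512])
print(sample["answers"])          # the class index of the most common answer
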
######################################################################
# 4. Finally, we import the ``flava_model_for_classification`` from
#    ``torchmultimodal``. It loads the pretrained FLAVA checkpoint by default and
#    includes a classification head.
#
#    The model forward function passes the image through the visual encoder
#    and the question through the text encoder. The image and question
#    embeddings are then passed through the multimodal encoder. The final
#    embedding corresponding to the CLS token is passed through an MLP head,
#    which gives the probability distribution over the possible answers.
#

from torchmultimodal.models.flava.model import flava_model_for_classification

# The classification head is sized to the number of answer classes in our vocab.
model = flava_model_for_classification(num_classes=len(vocab))

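######################################################################
# The returned model is a regular ``torch.nn.Module``, so standard PyTorch
# utilities apply to it. For example, we can count the trainable parameters to
# get a rough sense of the finetuning cost:
#

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")
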
######################################################################
# 5. We put together the dataset and the model in a toy training loop to
#    demonstrate how to train the model for 3 iterations:
#

from torch import nn
from torch.utils.data import DataLoader

BATCH_SIZE = 2
MAX_STEPS = 3

train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())

epochs = 1
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        out = model(text=batch["input_ids"], image=batch["image"], labels=batch["answers"])
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss}")
        # Stop early to keep this a quick demonstration.
        if idx >= MAX_STEPS - 1:
            break

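######################################################################
# After training, you would typically persist the finetuned weights so they can
# be reloaded later for evaluation or inference. A minimal sketch using standard
# PyTorch serialization (the file name here is arbitrary):
#
# .. code-block::
#
#    torch.save(model.state_dict(), "flava_textvqa_finetuned.pt")
#
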
######################################################################
# Conclusion
# -------------------
#
# This tutorial introduced the basics of how to finetune on a
# multimodal task using FLAVA from TorchMultimodal. Please also check out
# other examples from the library, such as
# `MDETR <https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr>`__,
# a multimodal model for object detection, and
# `Omnivore <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py>`__,
# a multitask model spanning image, video, and 3D classification.
#