Path: blob/master/examples/nlp/multimodal_entailment.py
3507 views
"""1Title: Multimodal entailment2Author: [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/08/084Last modified: 2025/01/035Description: Training a multimodal model for predicting entailment.6Accelerator: GPU7Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)8"""910"""11## Introduction1213In this example, we will build and train a model for predicting multimodal entailment. We will be14using the15[multimodal entailment dataset](https://github.com/google-research-datasets/recognizing-multimodal-entailment)16recently introduced by Google Research.1718### What is multimodal entailment?1920On social media platforms, to audit and moderate content21we may want to find answers to the22following questions in near real-time:2324* Does a given piece of information contradict the other?25* Does a given piece of information imply the other?2627In NLP, this task is called analyzing _textual entailment_. However, that's only28when the information comes from text content.29In practice, it's often the case the information available comes not just30from text content, but from a multimodal combination of text, images, audio, video, etc.31_Multimodal entailment_ is simply the extension of textual entailment to a variety32of new input modalities.3334### Requirements3536This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and37TensorFlow Text are required for the BERT model38([Devlin et al.](https://arxiv.org/abs/1810.04805)). These libraries can be installed39using the following command:40"""4142"""shell43pip install -q tensorflow_text44"""4546"""47## Imports48"""4950from sklearn.model_selection import train_test_split51import matplotlib.pyplot as plt52import pandas as pd53import numpy as np54import random55import math56from skimage.io import imread57from skimage.transform import resize58from PIL import Image59import os6061os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch6263import keras64import keras_hub65from keras.utils import PyDataset6667"""68## Define a label map69"""7071label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}7273"""74## Collect the dataset7576The original dataset is available77[here](https://github.com/google-research-datasets/recognizing-multimodal-entailment).78It comes with URLs of images which are hosted on Twitter's photo storage system called79the80[Photo Blob Storage (PBS for short)](https://blog.twitter.com/engineering/en_us/a/2012/blobstore-twitter-s-in-house-photo-storage-system).81We will be working with the downloaded images along with additional data that comes with82the original dataset. Thanks to83[Nilabhra Roy Chowdhury](https://de.linkedin.com/in/nilabhraroychowdhury) who worked on84preparing the image data.85"""8687image_base_path = keras.utils.get_file(88"tweet_images",89"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",90untar=True,91)9293"""94## Read the dataset and apply basic preprocessing95"""9697df = pd.read_csv(98"https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"99).iloc[1000:1000101] # Resources conservation since these are examples and not SOTA102df.sample(10)103104"""105The columns we are interested in are the following:106107* `text_1`108* `image_1`109* `text_2`110* `image_2`111* `label`112113The entailment task is formulated as the following:114115***Given the pairs of (`text_1`, `image_1`) and (`text_2`, `image_2`) do they entail (or116not entail or contradict) each other?***117118We have the images already downloaded. `image_1` is downloaded as `id1` as its filename119and `image2` is downloaded as `id2` as its filename. In the next step, we will add two120more columns to `df` - filepaths of `image_1`s and `image_2`s.121"""122123images_one_paths = []124images_two_paths = []125126for idx in range(len(df)):127current_row = df.iloc[idx]128id_1 = current_row["id_1"]129id_2 = current_row["id_2"]130extentsion_one = current_row["image_1"].split(".")[-1]131extentsion_two = current_row["image_2"].split(".")[-1]132133image_one_path = os.path.join(image_base_path, str(id_1) + f".{extentsion_one}")134image_two_path = os.path.join(image_base_path, str(id_2) + f".{extentsion_two}")135136images_one_paths.append(image_one_path)137images_two_paths.append(image_two_path)138139df["image_1_path"] = images_one_paths140df["image_2_path"] = images_two_paths141142# Create another column containing the integer ids of143# the string labels.144df["label_idx"] = df["label"].apply(lambda x: label_map[x])145146"""147## Dataset visualization148"""149150151def visualize(idx):152current_row = df.iloc[idx]153image_1 = plt.imread(current_row["image_1_path"])154image_2 = plt.imread(current_row["image_2_path"])155text_1 = current_row["text_1"]156text_2 = current_row["text_2"]157label = current_row["label"]158159plt.subplot(1, 2, 1)160plt.imshow(image_1)161plt.axis("off")162plt.title("Image One")163plt.subplot(1, 2, 2)164plt.imshow(image_1)165plt.axis("off")166plt.title("Image Two")167plt.show()168169print(f"Text one: {text_1}")170print(f"Text two: {text_2}")171print(f"Label: {label}")172173174random_idx = random.choice(range(len(df)))175visualize(random_idx)176177random_idx = random.choice(range(len(df)))178visualize(random_idx)179180"""181## Train/test split182183The dataset suffers from184[class imbalance problem](https://developers.google.com/machine-learning/glossary#class-imbalanced-dataset).185We can confirm that in the following cell.186"""187188df["label"].value_counts()189190"""191To account for that we will go for a stratified split.192"""193194# 10% for test195train_df, test_df = train_test_split(196df, test_size=0.1, stratify=df["label"].values, random_state=42197)198# 5% for validation199train_df, val_df = train_test_split(200train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42201)202203print(f"Total training examples: {len(train_df)}")204print(f"Total validation examples: {len(val_df)}")205print(f"Total test examples: {len(test_df)}")206207"""208## Data input pipeline209210Keras Hub provides211[variety of BERT family of models](https://keras.io/keras_hub/presets/).212Each of those models comes with a213corresponding preprocessing layer. You can learn more about these models and their214preprocessing layers from215[this resource](https://www.kaggle.com/models/keras/bert/keras/bert_base_en_uncased/2).216217To keep the runtime of this example relatively short, we will use a base_unacased variant of218the original BERT model.219"""220221"""222text preprocessing using KerasHub223"""224225text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(226"bert_base_en_uncased",227sequence_length=128,228)229230"""231### Run the preprocessor on a sample input232"""233234idx = random.choice(range(len(train_df)))235row = train_df.iloc[idx]236sample_text_1, sample_text_2 = row["text_1"], row["text_2"]237print(f"Text 1: {sample_text_1}")238print(f"Text 2: {sample_text_2}")239240test_text = [sample_text_1, sample_text_2]241text_preprocessed = text_preprocessor(test_text)242243print("Keys : ", list(text_preprocessed.keys()))244print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)245print("Token Ids : ", text_preprocessed["token_ids"][0, :16])246print(" Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)247print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])248print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)249print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])250251252"""253We will now create `tf.data.Dataset` objects from the dataframes.254255Note that the text inputs will be preprocessed as a part of the data input pipeline. But256the preprocessing modules can also be a part of their corresponding BERT models. This257helps reduce the training/serving skew and lets our models operate with raw text inputs.258Follow [this tutorial](https://www.tensorflow.org/text/tutorials/classify_text_with_bert)259to learn more about how to incorporate the preprocessing modules directly inside the260models.261"""262263264def dataframe_to_dataset(dataframe):265columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]266ds = UnifiedPyDataset(267dataframe,268batch_size=32,269workers=4,270)271return ds272273274"""275### Preprocessing utilities276"""277278bert_input_features = ["padding_mask", "segment_ids", "token_ids"]279280281def preprocess_text(text_1, text_2):282output = text_preprocessor([text_1, text_2])283output = {284feature: keras.ops.reshape(output[feature], [-1])285for feature in bert_input_features286}287return output288289290"""291### Create the final datasets, method adapted from PyDataset doc string.292"""293294295class UnifiedPyDataset(PyDataset):296"""A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""297298def __init__(299self,300df,301batch_size=32,302workers=4,303use_multiprocessing=False,304max_queue_size=10,305**kwargs,306):307"""308Args:309df: pandas DataFrame with data310batch_size: Batch size for dataset311workers: Number of workers to use for parallel loading (Keras)312use_multiprocessing: Whether to use multiprocessing313max_queue_size: Maximum size of the data queue for parallel loading314"""315super().__init__(**kwargs)316self.dataframe = df317columns = ["image_1_path", "image_2_path", "text_1", "text_2"]318# image files319self.image_x_1 = self.dataframe["image_1_path"]320self.image_x_2 = self.dataframe["image_1_path"]321self.image_y = self.dataframe["label_idx"]322# text files323self.text_x_1 = self.dataframe["text_1"]324self.text_x_2 = self.dataframe["text_2"]325self.text_y = self.dataframe["label_idx"]326# general327self.batch_size = batch_size328self.workers = workers329self.use_multiprocessing = use_multiprocessing330self.max_queue_size = max_queue_size331332def __getitem__(self, index):333"""334Fetches a batch of data from the dataset at the given index.335"""336337# Return x, y for batch idx.338low = index * self.batch_size339# Cap upper bound at array length; the last batch may be smaller340# if the total number of items is not a multiple of batch size.341# image files342high_image_1 = min(low + self.batch_size, len(self.image_x_1))343high_image_2 = min(low + self.batch_size, len(self.image_x_2))344# text345high_text_1 = min(low + self.batch_size, len(self.text_x_1))346high_text_2 = min(low + self.batch_size, len(self.text_x_1))347# images files348batch_image_x_1 = self.image_x_1[low:high_image_1]349batch_image_y_1 = self.image_y[low:high_image_1]350batch_image_x_2 = self.image_x_2[low:high_image_2]351batch_image_y_2 = self.image_y[low:high_image_2]352# text files353batch_text_x_1 = self.text_x_1[low:high_text_1]354batch_text_y_1 = self.text_y[low:high_text_1]355batch_text_x_2 = self.text_x_2[low:high_text_2]356batch_text_y_2 = self.text_y[low:high_text_2]357# image number 1 inputs358image_1 = [359resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1360]361image_1 = [362( # exeperienced some shapes which were different from others.363np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))364if img.shape[2] == 4365else img366)367for img in image_1368]369image_1 = np.array(image_1)370# Both text inputs to the model, return a dict for inputs to BertBackbone371text = {372key: np.array(373[374d[key]375for d in [376preprocess_text(file_path1, file_path2)377for file_path1, file_path2 in zip(378batch_text_x_1, batch_text_x_2379)380]381]382)383for key in ["padding_mask", "token_ids", "segment_ids"]384}385# Image number 2 model inputs386image_2 = [387resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2388]389image_2 = [390( # exeperienced some shapes which were different from others391np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))392if img.shape[2] == 4393else img394)395for img in image_2396]397# Stack the list comprehension to an nd.array398image_2 = np.array(image_2)399return (400{401"image_1": image_1,402"image_2": image_2,403"padding_mask": text["padding_mask"],404"segment_ids": text["segment_ids"],405"token_ids": text["token_ids"],406},407# Target lables408np.array(batch_image_y_1),409)410411def __len__(self):412"""413Returns the number of batches in the dataset.414"""415return math.ceil(len(self.dataframe) / self.batch_size)416417418"""419Create train, validation and test datasets420"""421422423def prepare_dataset(dataframe):424ds = dataframe_to_dataset(dataframe)425return ds426427428train_ds = prepare_dataset(train_df)429validation_ds = prepare_dataset(val_df)430test_ds = prepare_dataset(test_df)431432"""433## Model building utilities434435Our final model will accept two images along with their text counterparts. While the436images will be directly fed to the model the text inputs will first be preprocessed and437then will make it into the model. Below is a visual illustration of this approach:438439440441The model consists of the following elements:442443* A standalone encoder for the images. We will use a444[ResNet50V2](https://arxiv.org/abs/1603.05027) pre-trained on the ImageNet-1k dataset for445this.446* A standalone encoder for the images. A pre-trained BERT will be used for this.447448After extracting the individual embeddings, they will be projected in an identical space.449Finally, their projections will be concatenated and be fed to the final classification450layer.451452This is a multi-class classification problem involving the following classes:453454* NoEntailment455* Implies456* Contradictory457458`project_embeddings()`, `create_vision_encoder()`, and `create_text_encoder()` utilities459are referred from [this example](https://keras.io/examples/nlp/nl_image_search/).460"""461462"""463Projection utilities464"""465466467def project_embeddings(468embeddings, num_projection_layers, projection_dims, dropout_rate469):470projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)471for _ in range(num_projection_layers):472x = keras.ops.nn.gelu(projected_embeddings)473x = keras.layers.Dense(projection_dims)(x)474x = keras.layers.Dropout(dropout_rate)(x)475x = keras.layers.Add()([projected_embeddings, x])476projected_embeddings = keras.layers.LayerNormalization()(x)477return projected_embeddings478479480"""481Vision encoder utilities482"""483484485def create_vision_encoder(486num_projection_layers, projection_dims, dropout_rate, trainable=False487):488# Load the pre-trained ResNet50V2 model to be used as the base encoder.489resnet_v2 = keras.applications.ResNet50V2(490include_top=False, weights="imagenet", pooling="avg"491)492# Set the trainability of the base encoder.493for layer in resnet_v2.layers:494layer.trainable = trainable495496# Receive the images as inputs.497image_1 = keras.Input(shape=(128, 128, 3), name="image_1")498image_2 = keras.Input(shape=(128, 128, 3), name="image_2")499500# Preprocess the input image.501preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)502preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)503504# Generate the embeddings for the images using the resnet_v2 model505# concatenate them.506embeddings_1 = resnet_v2(preprocessed_1)507embeddings_2 = resnet_v2(preprocessed_2)508embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])509510# Project the embeddings produced by the model.511outputs = project_embeddings(512embeddings, num_projection_layers, projection_dims, dropout_rate513)514# Create the vision encoder model.515return keras.Model([image_1, image_2], outputs, name="vision_encoder")516517518"""519Text encoder utilities520"""521522523def create_text_encoder(524num_projection_layers, projection_dims, dropout_rate, trainable=False525):526# Load the pre-trained BERT BackBone using KerasHub.527bert = keras_hub.models.BertBackbone.from_preset(528"bert_base_en_uncased", num_classes=3529)530531# Set the trainability of the base encoder.532bert.trainable = trainable533534# Receive the text as inputs.535bert_input_features = ["padding_mask", "segment_ids", "token_ids"]536inputs = {537feature: keras.Input(shape=(256,), dtype="int32", name=feature)538for feature in bert_input_features539}540541# Generate embeddings for the preprocessed text using the BERT model.542embeddings = bert(inputs)["pooled_output"]543544# Project the embeddings produced by the model.545outputs = project_embeddings(546embeddings, num_projection_layers, projection_dims, dropout_rate547)548# Create the text encoder model.549return keras.Model(inputs, outputs, name="text_encoder")550551552"""553Multimodal model utilities554"""555556557def create_multimodal_model(558num_projection_layers=1,559projection_dims=256,560dropout_rate=0.1,561vision_trainable=False,562text_trainable=False,563):564# Receive the images as inputs.565image_1 = keras.Input(shape=(128, 128, 3), name="image_1")566image_2 = keras.Input(shape=(128, 128, 3), name="image_2")567568# Receive the text as inputs.569bert_input_features = ["padding_mask", "segment_ids", "token_ids"]570text_inputs = {571feature: keras.Input(shape=(256,), dtype="int32", name=feature)572for feature in bert_input_features573}574text_inputs = list(text_inputs.values())575# Create the encoders.576vision_encoder = create_vision_encoder(577num_projection_layers, projection_dims, dropout_rate, vision_trainable578)579text_encoder = create_text_encoder(580num_projection_layers, projection_dims, dropout_rate, text_trainable581)582583# Fetch the embedding projections.584vision_projections = vision_encoder([image_1, image_2])585text_projections = text_encoder(text_inputs)586587# Concatenate the projections and pass through the classification layer.588concatenated = keras.layers.Concatenate()([vision_projections, text_projections])589outputs = keras.layers.Dense(3, activation="softmax")(concatenated)590return keras.Model([image_1, image_2, *text_inputs], outputs)591592593multimodal_model = create_multimodal_model()594keras.utils.plot_model(multimodal_model, show_shapes=True)595596"""597You can inspect the structure of the individual encoders as well by setting the598`expand_nested` argument of `plot_model()` to `True`. You are encouraged599to play with the different hyperparameters involved in building this model and600observe how the final performance is affected.601"""602603"""604## Compile and train the model605"""606607multimodal_model.compile(608optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]609)610611history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)612613"""614## Evaluate the model615"""616617_, acc = multimodal_model.evaluate(test_ds)618print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")619620"""621## Additional notes regarding training622623**Incorporating regularization**:624625The training logs suggest that the model is starting to overfit and may have benefitted626from regularization. Dropout ([Srivastava et al.](https://jmlr.org/papers/v15/srivastava14a.html))627is a simple yet powerful regularization technique that we can use in our model.628But how should we apply it here?629630We could always introduce Dropout (`keras.layers.Dropout`) in between different layers of the model.631But here is another recipe. Our model expects inputs from two different data modalities.632What if either of the modalities is not present during inference? To account for this,633we can introduce Dropout to the individual projections just before they get concatenated:634635```python636vision_projections = keras.layers.Dropout(rate)(vision_projections)637text_projections = keras.layers.Dropout(rate)(text_projections)638concatenated = keras.layers.Concatenate()([vision_projections, text_projections])639```640641**Attending to what matters**:642643Do all parts of the images correspond equally to their textual counterparts? It's likely644not the case. To make our model only focus on the most important bits of the images that relate645well to their corresponding textual parts we can use "cross-attention":646647```python648# Embeddings.649vision_projections = vision_encoder([image_1, image_2])650text_projections = text_encoder(text_inputs)651652# Cross-attention (Luong-style).653query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(654[vision_projections, text_projections]655)656# Concatenate.657concatenated = keras.layers.Concatenate()([vision_projections, text_projections])658contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])659```660661To see this in action, refer to662[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment_attn.ipynb).663664**Handling class imbalance**:665666The dataset suffers from class imbalance. Investigating the confusion matrix of the667above model reveals that it performs poorly on the minority classes. If we had used a668weighted loss then the training would have been more guided. You can check out669[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment.ipynb)670that takes class-imbalance into account during model training.671672**Using only text inputs**:673674Also, what if we had only incorporated text inputs for the entailment task? Because of675the nature of the text inputs encountered on social media platforms, text inputs alone676would have hurt the final performance. Under a similar training setup, by only using677text inputs we get to 67.14% top-1 accuracy on the same test set. Refer to678[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/text_entailment.ipynb)679for details.680681Finally, here is a table comparing different approaches taken for the entailment task:682683| Type | Standard<br>Cross-entropy | Loss-weighted<br>Cross-entropy | Focal Loss |684|:---: |:---: |:---: |:---: |685| Multimodal | 77.86% | 67.86% | 86.43% |686| Only text | 67.14% | 11.43% | 37.86% |687688You can check out [this repository](https://git.io/JR0HU) to learn more about how the689experiments were conducted to obtain these numbers.690"""691692"""693## Final remarks694695* The architecture we used in this example is too large for the number of data points696available for training. It's going to benefit from more data.697* We used a smaller variant of the original BERT model. Chances are high that with a698larger variant, this performance will be improved. TensorFlow Hub699[provides](https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub)700a number of different BERT models that you can experiment with.701* We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment702task would could resulted in better performance.703* We built a simple baseline model for the multimodal entailment task. There are various704approaches that have been proposed to tackle the entailment problem.705[This presentation deck](https://docs.google.com/presentation/d/1mAB31BCmqzfedreNZYn4hsKPFmgHA9Kxz219DzyRY3c/edit?usp=sharing)706from the707[Recognizing Multimodal Entailment](https://multimodal-entailment.github.io/)708tutorial provides a comprehensive overview.709710You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment)711and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/multimodal_entailment)712"""713714715