"""
Title: Multimodal entailment
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/08/08
Last modified: 2025/01/03
Description: Training a multimodal model for predicting entailment.
Accelerator: GPU
Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""

"""
## Introduction

In this example, we will build and train a model for predicting multimodal entailment. We will be
using the
[multimodal entailment dataset](https://github.com/google-research-datasets/recognizing-multimodal-entailment)
recently introduced by Google Research.

### What is multimodal entailment?

On social media platforms, to audit and moderate content
we may want to find answers to the
following questions in near real-time:

* Does a given piece of information contradict the other?
* Does a given piece of information imply the other?

In NLP, this task is called analyzing _textual entailment_. However, that's only
when the information comes from text content.
In practice, it's often the case that the available information comes not just
from text content, but from a multimodal combination of text, images, audio, video, etc.
_Multimodal entailment_ is simply the extension of textual entailment to a variety
of new input modalities.

### Requirements

This example uses KerasHub to load the BERT model
([Devlin et al.](https://arxiv.org/abs/1810.04805)) and its matching preprocessing layer.
KerasHub can be installed using the following command:
"""

"""shell
pip install -q keras-hub
"""

"""
## Imports
"""

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os

os.environ["KERAS_BACKEND"] = "jax"  # or tensorflow, or torch

import keras
import keras_hub
from keras.utils import PyDataset

"""
## Define a label map
"""

label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
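
"""
The inverse mapping comes in handy when turning integer predictions back into the
original string labels. A minimal sketch (the `idx_to_label` name is ours and is not
used elsewhere in this example):

```python
idx_to_label = {v: k for k, v in label_map.items()}
print(idx_to_label[2])  # "NoEntailment"
```
"""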

"""
## Collect the dataset

The original dataset is available
[here](https://github.com/google-research-datasets/recognizing-multimodal-entailment).
It comes with URLs of images which are hosted on Twitter's photo storage system called
the
[Photo Blob Storage (PBS for short)](https://blog.twitter.com/engineering/en_us/a/2012/blobstore-twitter-s-in-house-photo-storage-system).
We will be working with the downloaded images along with additional data that comes with
the original dataset. Thanks to
[Nilabhra Roy Chowdhury](https://de.linkedin.com/in/nilabhraroychowdhury) who worked on
preparing the image data.
"""

image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)

"""
## Read the dataset and apply basic preprocessing
"""

df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[
    0:1000
]  # Resources conservation since these are examples and not SOTA
df.sample(10)

"""
The columns we are interested in are the following:

* `text_1`
* `image_1`
* `text_2`
* `image_2`
* `label`

The entailment task is formulated as the following:

***Given the pairs of (`text_1`, `image_1`) and (`text_2`, `image_2`), do they entail (or
not entail or contradict) each other?***

We already have the images downloaded: `image_1` is saved with `id_1` as its filename and
`image_2` with `id_2` as its filename. In the next step, we will add two more columns to
`df`: the filepaths of the `image_1`s and `image_2`s.
"""

images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])

"""
## Dataset visualization
"""


def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = random.choice(range(len(df)))
visualize(random_idx)

random_idx = random.choice(range(len(df)))
visualize(random_idx)

"""
## Train/test split

The dataset suffers from the
[class imbalance problem](https://developers.google.com/machine-learning/glossary#class-imbalanced-dataset).
We can confirm that in the following cell.
"""

df["label"].value_counts()

"""
To account for that, we will use a stratified split.
"""

# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% of the remaining examples for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")

"""
## Data input pipeline

KerasHub provides a
[variety of models from the BERT family](https://keras.io/keras_hub/presets/).
Each of those models comes with a
corresponding preprocessing layer. You can learn more about these models and their
preprocessing layers from
[this resource](https://www.kaggle.com/models/keras/bert/keras/bert_base_en_uncased/2).

To keep the runtime of this example relatively short, we will use the
`bert_base_en_uncased` variant of the original BERT model.
"""

"""
Text preprocessing using KerasHub
"""

text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)

"""
### Run the preprocessor on a sample input
"""

idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")

test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)

print("Keys : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids : ", text_preprocessed["token_ids"][0, :16])
print("Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])


"""
We will now create batched dataset objects from the dataframes using
`keras.utils.PyDataset`, which works across the TensorFlow, JAX and PyTorch backends.

Note that the text inputs will be preprocessed as a part of the data input pipeline. But
the preprocessing modules can also be a part of their corresponding BERT models. This
helps reduce the training/serving skew and lets our models operate with raw text inputs.
Follow [this tutorial](https://www.tensorflow.org/text/tutorials/classify_text_with_bert)
to learn more about how to incorporate the preprocessing modules directly inside the
models.
"""


def dataframe_to_dataset(dataframe):
    columns = ["image_1_path", "image_2_path", "text_1", "text_2", "label_idx"]
    ds = UnifiedPyDataset(
        dataframe,
        batch_size=32,
        workers=4,
    )
    return ds


"""
### Preprocessing utilities
"""

bert_input_features = ["padding_mask", "segment_ids", "token_ids"]


def preprocess_text(text_1, text_2):
    output = text_preprocessor([text_1, text_2])
    output = {
        feature: keras.ops.reshape(output[feature], [-1])
        for feature in bert_input_features
    }
    return output


"""
### Create the final datasets. The method below is adapted from the `PyDataset` docstring.
"""


class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with data
            batch_size: Batch size for dataset
            workers: Number of workers to use for parallel loading (Keras)
            use_multiprocessing: Whether to use multiprocessing
            max_queue_size: Maximum size of the data queue for parallel loading
        """
        super().__init__(**kwargs)
        self.dataframe = df
        columns = ["image_1_path", "image_2_path", "text_1", "text_2"]
        # image files
        self.image_x_1 = self.dataframe["image_1_path"]
        self.image_x_2 = self.dataframe["image_2_path"]
        self.image_y = self.dataframe["label_idx"]
        # text files
        self.text_x_1 = self.dataframe["text_1"]
        self.text_x_2 = self.dataframe["text_2"]
        self.text_y = self.dataframe["label_idx"]
        # general
        self.batch_size = batch_size
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size

    def __getitem__(self, index):
        """
        Fetches a batch of data from the dataset at the given index.
        """

        # Return x, y for batch idx.
        low = index * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        # image files
        high_image_1 = min(low + self.batch_size, len(self.image_x_1))
        high_image_2 = min(low + self.batch_size, len(self.image_x_2))
        # text
        high_text_1 = min(low + self.batch_size, len(self.text_x_1))
        high_text_2 = min(low + self.batch_size, len(self.text_x_2))
        # image files
        batch_image_x_1 = self.image_x_1[low:high_image_1]
        batch_image_y_1 = self.image_y[low:high_image_1]
        batch_image_x_2 = self.image_x_2[low:high_image_2]
        batch_image_y_2 = self.image_y[low:high_image_2]
        # text files
        batch_text_x_1 = self.text_x_1[low:high_text_1]
        batch_text_y_1 = self.text_y[low:high_text_1]
        batch_text_x_2 = self.text_x_2[low:high_text_2]
        batch_text_y_2 = self.text_y[low:high_text_2]
        # image number 1 inputs
        image_1 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
        ]
        image_1 = [
            (  # some images have an alpha channel (RGBA); convert those to RGB.
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_1
        ]
        image_1 = np.array(image_1)
        # Preprocess both text inputs and return a dict of inputs for the BERT backbone.
        text = {
            key: np.array(
                [
                    d[key]
                    for d in [
                        preprocess_text(file_path1, file_path2)
                        for file_path1, file_path2 in zip(
                            batch_text_x_1, batch_text_x_2
                        )
                    ]
                ]
            )
            for key in ["padding_mask", "token_ids", "segment_ids"]
        }
        # Image number 2 model inputs
        image_2 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
        ]
        image_2 = [
            (  # some images have an alpha channel (RGBA); convert those to RGB.
                np.array(Image.fromarray((img.astype(np.uint8))).convert("RGB"))
                if img.shape[2] == 4
                else img
            )
            for img in image_2
        ]
        # Stack the list comprehension to an nd.array
        image_2 = np.array(image_2)
        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                "padding_mask": text["padding_mask"],
                "segment_ids": text["segment_ids"],
                "token_ids": text["token_ids"],
            },
            # Target labels
            np.array(batch_image_y_1),
        )

    def __len__(self):
        """
        Returns the number of batches in the dataset.
        """
        return math.ceil(len(self.dataframe) / self.batch_size)


"""
Create the train, validation and test datasets
"""


def prepare_dataset(dataframe):
    ds = dataframe_to_dataset(dataframe)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)
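
"""
Each dataset yields `(inputs, labels)` batches, where `inputs` is a dict keyed by the
model's input names. A minimal sketch for inspecting a single batch (not part of the
original script):

```python
batch_inputs, batch_labels = train_ds[0]
for name, value in batch_inputs.items():
    print(name, value.shape)
print("labels:", batch_labels.shape)
```
"""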

"""
## Model building utilities

Our final model will accept two images along with their text counterparts. While the
images will be fed directly to the model, the text inputs will first be preprocessed and
then passed in. Below is a visual illustration of this approach:

![](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/figures/brief_architecture.png)

The model consists of the following elements:

* A standalone encoder for the images. We will use a
[ResNet50V2](https://arxiv.org/abs/1603.05027) pre-trained on the ImageNet-1k dataset for
this.
* A standalone encoder for the text. A pre-trained BERT model will be used for this.

After extracting the individual embeddings, they will be projected into an identical space.
Finally, their projections will be concatenated and fed to the final classification
layer.

This is a multi-class classification problem involving the following classes:

* NoEntailment
* Implies
* Contradictory

The `project_embeddings()`, `create_vision_encoder()`, and `create_text_encoder()` utilities
are adapted from [this example](https://keras.io/examples/nlp/nl_image_search/).
"""

"""
Projection utilities
"""


def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = keras.ops.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings


"""
Vision encoder utilities
"""


def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input images.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model and
    # concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")


"""
Text encoder utilities
"""


def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT backbone using KerasHub.
    bert = keras_hub.models.BertBackbone.from_preset(
        "bert_base_en_uncased", num_classes=3
    )

    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")


"""
Multimodal model utilities
"""


def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }
    text_inputs = list(text_inputs.values())
    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, *text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)

"""
You can inspect the structure of the individual encoders as well by setting the
`expand_nested` argument of `plot_model()` to `True`. You are encouraged
to play with the different hyperparameters involved in building this model and
observe how the final performance is affected.
"""

"""
## Compile and train the model
"""

multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)

"""
## Evaluate the model
"""

_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")

"""
## Additional notes regarding training

**Incorporating regularization**:

The training logs suggest that the model is starting to overfit and may have benefitted
from regularization. Dropout ([Srivastava et al.](https://jmlr.org/papers/v15/srivastava14a.html))
is a simple yet powerful regularization technique that we can use in our model.
But how should we apply it here?

We could always introduce Dropout (`keras.layers.Dropout`) in between different layers of the model.
But here is another recipe. Our model expects inputs from two different data modalities.
What if either of the modalities is not present during inference? To account for this,
we can introduce Dropout to the individual projections just before they get concatenated:

```python
vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
```

**Attending to what matters**:

Do all parts of the images correspond equally to their textual counterparts? It's likely
not the case. To make our model focus only on the most important bits of the images that relate
well to their corresponding textual parts, we can use "cross-attention":

```python
# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)

# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
    [vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
```

To see this in action, refer to
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment_attn.ipynb).

**Handling class imbalance**:

The dataset suffers from class imbalance. Investigating the confusion matrix of the
above model reveals that it performs poorly on the minority classes. If we had used a
weighted loss, the training would have been more guided. You can check out
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment.ipynb)
that takes class imbalance into account during model training.
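
One simple recipe is to pass per-class weights to `model.fit()`. A minimal sketch (the
weights below are computed with scikit-learn's `compute_class_weight` helper; this is
an illustration, not necessarily how the linked notebook handles the imbalance):

```python
from sklearn.utils.class_weight import compute_class_weight

classes = np.unique(train_df["label_idx"])
weights = compute_class_weight("balanced", classes=classes, y=train_df["label_idx"])
class_weight = {int(c): w for c, w in zip(classes, weights)}

multimodal_model.fit(
    train_ds, validation_data=validation_ds, epochs=1, class_weight=class_weight
)
```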

**Using only text inputs**:

Also, what if we had only incorporated text inputs for the entailment task? Because of
the nature of the text inputs encountered on social media platforms, text inputs alone
would have hurt the final performance. Under a similar training setup, by only using
text inputs we get to 67.14% top-1 accuracy on the same test set. Refer to
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/text_entailment.ipynb)
for details.

Finally, here is a table comparing different approaches taken for the entailment task:

| Type | Standard<br>Cross-entropy | Loss-weighted<br>Cross-entropy | Focal Loss |
|:---: |:---: |:---: |:---: |
| Multimodal | 77.86% | 67.86% | 86.43% |
| Only text | 67.14% | 11.43% | 37.86% |

You can check out [this repository](https://git.io/JR0HU) to learn more about how the
experiments were conducted to obtain these numbers.
"""

"""
## Final remarks

* The architecture we used in this example is too large for the number of data points
available for training. It's going to benefit from more data.
* We used the base variant of the original BERT model. Chances are high that with a
larger variant, this performance will be improved. KerasHub
[provides](https://keras.io/keras_hub/presets/)
a number of different BERT models that you can experiment with.
* We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment
task could result in better performance.
* We built a simple baseline model for the multimodal entailment task. There are various
approaches that have been proposed to tackle the entailment problem.
[This presentation deck](https://docs.google.com/presentation/d/1mAB31BCmqzfedreNZYn4hsKPFmgHA9Kxz219DzyRY3c/edit?usp=sharing)
from the
[Recognizing Multimodal Entailment](https://multimodal-entailment.github.io/)
tutorial provides a comprehensive overview.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/multimodal_entailment).
"""