"""
2
Title: Image Captioning
3
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
4
Date created: 2021/05/29
5
Last modified: 2021/10/31
6
Description: Implement an image captioning model using a CNN and a Transformer.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Setup
12
"""
13
14
import os
15
16
os.environ["KERAS_BACKEND"] = "tensorflow"
17
18
import re
19
import numpy as np
20
import matplotlib.pyplot as plt
21
22
import tensorflow as tf
23
import keras
24
from keras import layers
25
from keras.applications import efficientnet
26
from keras.layers import TextVectorization
27
28
keras.utils.set_random_seed(111)
29
30
"""
31
## Download the dataset
32
33
We will be using the Flickr8K dataset for this tutorial. This dataset comprises over
34
8,000 images, that are each paired with five different captions.
35
"""
36
37
38
"""shell
39
wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
40
wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
41
unzip -qq Flickr8k_Dataset.zip
42
unzip -qq Flickr8k_text.zip
43
rm Flickr8k_Dataset.zip Flickr8k_text.zip
44
"""
45
46
47
# Path to the images
48
IMAGES_PATH = "Flicker8k_Dataset"
49
50
# Desired image dimensions
51
IMAGE_SIZE = (299, 299)
52
53
# Vocabulary size
54
VOCAB_SIZE = 10000
55
56
# Fixed length allowed for any sequence
57
SEQ_LENGTH = 25
58
59
# Dimension for the image embeddings and token embeddings
60
EMBED_DIM = 512
61
62
# Per-layer units in the feed-forward network
63
FF_DIM = 512
64
65
# Other training parameters
66
BATCH_SIZE = 64
67
EPOCHS = 30
68
AUTOTUNE = tf.data.AUTOTUNE
69
70
"""
71
## Preparing the dataset
72
"""
73
74
75
def load_captions_data(filename):
76
"""Loads captions (text) data and maps them to corresponding images.
77
78
Args:
79
filename: Path to the text file containing caption data.
80
81
Returns:
82
caption_mapping: Dictionary mapping image names and the corresponding captions
83
text_data: List containing all the available captions
84
"""
85
86
with open(filename) as caption_file:
87
caption_data = caption_file.readlines()
88
caption_mapping = {}
89
text_data = []
90
images_to_skip = set()
91
92
for line in caption_data:
93
line = line.rstrip("\n")
94
# Image name and captions are separated using a tab
95
img_name, caption = line.split("\t")
96
97
# Each image is repeated five times for the five different captions.
98
# Each image name has a suffix `#(caption_number)`
99
img_name = img_name.split("#")[0]
100
img_name = os.path.join(IMAGES_PATH, img_name.strip())
101
102
# We will remove caption that are either too short to too long
103
tokens = caption.strip().split()
104
105
if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
106
images_to_skip.add(img_name)
107
continue
108
109
if img_name.endswith("jpg") and img_name not in images_to_skip:
110
# We will add a start and an end token to each caption
111
caption = "<start> " + caption.strip() + " <end>"
112
text_data.append(caption)
113
114
if img_name in caption_mapping:
115
caption_mapping[img_name].append(caption)
116
else:
117
caption_mapping[img_name] = [caption]
118
119
for img_name in images_to_skip:
120
if img_name in caption_mapping:
121
del caption_mapping[img_name]
122
123
return caption_mapping, text_data
124
125
126
def train_val_split(caption_data, train_size=0.8, shuffle=True):
127
"""Split the captioning dataset into train and validation sets.
128
129
Args:
130
caption_data (dict): Dictionary containing the mapped caption data
131
train_size (float): Fraction of all the full dataset to use as training data
132
shuffle (bool): Whether to shuffle the dataset before splitting
133
134
Returns:
135
Traning and validation datasets as two separated dicts
136
"""
137
138
# 1. Get the list of all image names
139
all_images = list(caption_data.keys())
140
141
# 2. Shuffle if necessary
142
if shuffle:
143
np.random.shuffle(all_images)
144
145
# 3. Split into training and validation sets
146
train_size = int(len(caption_data) * train_size)
147
148
training_data = {
149
img_name: caption_data[img_name] for img_name in all_images[:train_size]
150
}
151
validation_data = {
152
img_name: caption_data[img_name] for img_name in all_images[train_size:]
153
}
154
155
# 4. Return the splits
156
return training_data, validation_data
157
158
159
# Load the dataset
160
captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")
161
162
# Split the dataset into training and validation sets
163
train_data, valid_data = train_val_split(captions_mapping)
164
print("Number of training samples: ", len(train_data))
165
print("Number of validation samples: ", len(valid_data))
166
167
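
"""
As a quick sanity check (an illustrative addition, not part of the original example), we
can peek at one entry of the mapping: each key is an image path, and each value is the
list of five captions wrapped in `<start>`/`<end>` tokens.
"""

# Illustrative only: inspect a single image-to-captions entry
sample_image = list(train_data.keys())[0]
print(sample_image)
print(train_data[sample_image])
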
"""
168
## Vectorizing the text data
169
170
We'll use the `TextVectorization` layer to vectorize the text data,
171
that is to say, to turn the
172
original strings into integer sequences where each integer represents the index of
173
a word in a vocabulary. We will use a custom string standardization scheme
174
(strip punctuation characters except `<` and `>`) and the default
175
splitting scheme (split on whitespace).
176
"""
177
178
179
def custom_standardization(input_string):
180
lowercase = tf.strings.lower(input_string)
181
return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
182
183
184
strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
185
strip_chars = strip_chars.replace("<", "")
186
strip_chars = strip_chars.replace(">", "")
187
188
vectorization = TextVectorization(
189
max_tokens=VOCAB_SIZE,
190
output_mode="int",
191
output_sequence_length=SEQ_LENGTH,
192
standardize=custom_standardization,
193
)
194
vectorization.adapt(text_data)
195
196
# Data augmentation for image data
197
image_augmentation = keras.Sequential(
198
[
199
layers.RandomFlip("horizontal"),
200
layers.RandomRotation(0.2),
201
layers.RandomContrast(0.3),
202
]
203
)
204
205
206
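
"""
As a quick illustration (not part of the original example), we can call the adapted layer
on a caption-style string: it lowercases the text, strips punctuation except `<` and `>`,
splits on whitespace, maps each word to its vocabulary index, and pads the output to
`SEQ_LENGTH` integers.
"""

# Illustrative only: vectorize a single caption-style string
print(vectorization(["<start> A dog runs across the field <end>"]))
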
"""
207
## Building a `tf.data.Dataset` pipeline for training
208
209
We will generate pairs of images and corresponding captions using a `tf.data.Dataset` object.
210
The pipeline consists of two steps:
211
212
1. Read the image from the disk
213
2. Tokenize all the five captions corresponding to the image
214
"""
215
216
217
def decode_and_resize(img_path):
218
img = tf.io.read_file(img_path)
219
img = tf.image.decode_jpeg(img, channels=3)
220
img = tf.image.resize(img, IMAGE_SIZE)
221
img = tf.image.convert_image_dtype(img, tf.float32)
222
return img
223
224
225
def process_input(img_path, captions):
226
return decode_and_resize(img_path), vectorization(captions)
227
228
229
def make_dataset(images, captions):
230
dataset = tf.data.Dataset.from_tensor_slices((images, captions))
231
dataset = dataset.shuffle(BATCH_SIZE * 8)
232
dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
233
dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
234
235
return dataset
236
237
238
# Pass the list of images and the list of corresponding captions
239
train_dataset = make_dataset(list(train_data.keys()), list(train_data.values()))
240
241
valid_dataset = make_dataset(list(valid_data.keys()), list(valid_data.values()))
242
243
244
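
"""
Before building the model, it is worth confirming what one element of the pipeline looks
like (an illustrative check, not part of the original example): a batch of images of shape
`(batch, 299, 299, 3)` paired with tokenized captions of shape `(batch, 5, SEQ_LENGTH)`,
since every image carries five captions.
"""

# Illustrative only: inspect the structure produced by `make_dataset`
print(train_dataset.element_spec)
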
"""
245
## Building the model
246
247
Our image captioning architecture consists of three models:
248
249
1. A CNN: used to extract the image features
250
2. A TransformerEncoder: The extracted image features are then passed to a Transformer
251
based encoder that generates a new representation of the inputs
252
3. A TransformerDecoder: This model takes the encoder output and the text data
253
(sequences) as inputs and tries to learn to generate the caption.
254
"""
255
256
257
def get_cnn_model():
258
base_model = efficientnet.EfficientNetB0(
259
input_shape=(*IMAGE_SIZE, 3),
260
include_top=False,
261
weights="imagenet",
262
)
263
# We freeze our feature extractor
264
base_model.trainable = False
265
base_model_out = base_model.output
266
base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
267
cnn_model = keras.models.Model(base_model.input, base_model_out)
268
return cnn_model
269
270
271
class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)
class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM,
            sequence_length=SEQ_LENGTH,
            vocab_size=VOCAB_SIZE,
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [
                tf.expand_dims(batch_size, -1),
                tf.constant([1, 1], dtype=tf.int32),
            ],
            axis=0,
        )
        return tf.tile(mask, mult)
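
"""
Inside the decoder, `get_causal_attention_mask` builds a lower-triangular mask so that
position `i` can only attend to positions `j <= i`. A minimal sketch of the same pattern
(illustrative only, using `tf.linalg.band_part` instead of the class method):
"""

# Illustrative only: a 4x4 causal attention pattern (1 = attend, 0 = masked)
print(tf.linalg.band_part(tf.ones((4, 4), dtype=tf.int32), -1, 0))
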
class ImageCaptioningModel(keras.Model):
    def __init__(
        self,
        cnn_model,
        encoder,
        decoder,
        num_captions_per_image=5,
        image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    @property
    def metrics(self):
        # We need to list our metrics here so the `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]
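
"""
Both the loss and the accuracy above are masked averages: padded positions (token id 0)
are zeroed out and the sum is divided by the number of real tokens. A tiny numeric sketch
of the same idea (illustrative only):
"""

# Illustrative only: masked mean over two real tokens and two padding positions
toy_loss = tf.constant([2.0, 4.0, 9.0, 9.0])
toy_mask = tf.constant([1.0, 1.0, 0.0, 0.0])
print(tf.reduce_sum(toy_loss * toy_mask) / tf.reduce_sum(toy_mask))  # -> 3.0
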
cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)
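
"""
To make the data flow concrete, the short check below (illustrative only, not part of the
original example) runs one batch through the frozen CNN and the encoder: the CNN turns each
image into a sequence of spatial features, and the encoder returns a representation with the
same sequence length and `EMBED_DIM` features per position.
"""

# Illustrative only: trace tensor shapes through the CNN and the encoder
sample_batch_img, sample_batch_seq = next(iter(train_dataset.take(1)))
sample_img_features = cnn_model(sample_batch_img)
sample_encoded = encoder(sample_img_features, training=False)
print("CNN features shape:", sample_img_features.shape)
print("Encoder output shape:", sample_encoded.shape)
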
"""
547
## Model training
548
"""
549
550
551
# Define the loss function
552
cross_entropy = keras.losses.SparseCategoricalCrossentropy(
553
from_logits=False,
554
reduction=None,
555
)
556
557
# EarlyStopping criteria
558
early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
559
560
561
# Learning Rate Scheduler for the optimizer
562
class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
563
def __init__(self, post_warmup_learning_rate, warmup_steps):
564
super().__init__()
565
self.post_warmup_learning_rate = post_warmup_learning_rate
566
self.warmup_steps = warmup_steps
567
568
def __call__(self, step):
569
global_step = tf.cast(step, tf.float32)
570
warmup_steps = tf.cast(self.warmup_steps, tf.float32)
571
warmup_progress = global_step / warmup_steps
572
warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
573
return tf.cond(
574
global_step < warmup_steps,
575
lambda: warmup_learning_rate,
576
lambda: self.post_warmup_learning_rate,
577
)
578
579
580
# Create a learning rate schedule
581
num_train_steps = len(train_dataset) * EPOCHS
582
num_warmup_steps = num_train_steps // 15
583
lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)
584
585
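
"""
The schedule ramps the learning rate linearly from zero to `post_warmup_learning_rate`
during the warmup steps and then holds it constant. The quick check below (illustrative
only, not part of the original example) evaluates the schedule around the warmup boundary.
"""

# Illustrative only: sample the schedule before, at, and after the warmup boundary
for step in [1, num_warmup_steps // 2, num_warmup_steps, num_warmup_steps * 2]:
    print(step, float(lr_schedule(step)))
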
# Compile the model
caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy)

# Fit the model
caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)

"""
## Check sample predictions
"""
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = SEQ_LENGTH - 1
valid_images = list(valid_data.keys())


def generate_caption():
    # Select a random image from the validation dataset
    sample_img = np.random.choice(valid_images)

    # Read the image from the disk
    sample_img = decode_and_resize(sample_img)
    img = sample_img.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(img)
    plt.show()

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = caption_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)


# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()
"""
648
## End Notes
649
650
We saw that the model starts to generate reasonable captions after a few epochs. To keep
651
this example easily runnable, we have trained it with a few constraints, like a minimal
652
number of attention heads. To improve the predictions, you can try changing these training
653
settings and find a good model for your use case.
654
"""
655
656