"""1Title: Image Captioning2Author: [A_K_Nain](https://twitter.com/A_K_Nain)3Date created: 2021/05/294Last modified: 2021/10/315Description: Implement an image captioning model using a CNN and a Transformer.6Accelerator: GPU7"""89"""10## Setup11"""1213import os1415os.environ["KERAS_BACKEND"] = "tensorflow"1617import re18import numpy as np19import matplotlib.pyplot as plt2021import tensorflow as tf22import keras23from keras import layers24from keras.applications import efficientnet25from keras.layers import TextVectorization2627keras.utils.set_random_seed(111)2829"""30## Download the dataset3132We will be using the Flickr8K dataset for this tutorial. This dataset comprises over338,000 images, that are each paired with five different captions.34"""353637"""shell38wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip39wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip40unzip -qq Flickr8k_Dataset.zip41unzip -qq Flickr8k_text.zip42rm Flickr8k_Dataset.zip Flickr8k_text.zip43"""444546# Path to the images47IMAGES_PATH = "Flicker8k_Dataset"4849# Desired image dimensions50IMAGE_SIZE = (299, 299)5152# Vocabulary size53VOCAB_SIZE = 100005455# Fixed length allowed for any sequence56SEQ_LENGTH = 255758# Dimension for the image embeddings and token embeddings59EMBED_DIM = 5126061# Per-layer units in the feed-forward network62FF_DIM = 5126364# Other training parameters65BATCH_SIZE = 6466EPOCHS = 3067AUTOTUNE = tf.data.AUTOTUNE6869"""70## Preparing the dataset71"""727374def load_captions_data(filename):75"""Loads captions (text) data and maps them to corresponding images.7677Args:78filename: Path to the text file containing caption data.7980Returns:81caption_mapping: Dictionary mapping image names and the corresponding captions82text_data: List containing all the available captions83"""8485with open(filename) as caption_file:86caption_data = caption_file.readlines()87caption_mapping = {}88text_data = []89images_to_skip = set()9091for line in caption_data:92line = line.rstrip("\n")93# Image name and captions are separated using a tab94img_name, caption = line.split("\t")9596# Each image is repeated five times for the five different captions.97# Each image name has a suffix `#(caption_number)`98img_name = img_name.split("#")[0]99img_name = os.path.join(IMAGES_PATH, img_name.strip())100101# We will remove caption that are either too short to too long102tokens = caption.strip().split()103104if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:105images_to_skip.add(img_name)106continue107108if img_name.endswith("jpg") and img_name not in images_to_skip:109# We will add a start and an end token to each caption110caption = "<start> " + caption.strip() + " <end>"111text_data.append(caption)112113if img_name in caption_mapping:114caption_mapping[img_name].append(caption)115else:116caption_mapping[img_name] = [caption]117118for img_name in images_to_skip:119if img_name in caption_mapping:120del caption_mapping[img_name]121122return caption_mapping, text_data123124125def train_val_split(caption_data, train_size=0.8, shuffle=True):126"""Split the captioning dataset into train and validation sets.127128Args:129caption_data (dict): Dictionary containing the mapped caption data130train_size (float): Fraction of all the full dataset to use as training data131shuffle (bool): Whether to shuffle the dataset before splitting132133Returns:134Traning and validation datasets as two separated dicts135"""136137# 1. 

"""
## Vectorizing the text data

We'll use the `TextVectorization` layer to vectorize the text data,
that is to say, to turn the
original strings into integer sequences where each integer represents the index of
a word in a vocabulary. We will use a custom string standardization scheme
(strip punctuation characters except `<` and `>`) and the default
splitting scheme (split on whitespace).
"""


def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(text_data)

# Data augmentation for image data
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ]
)


"""
## Building a `tf.data.Dataset` pipeline for training

We will generate pairs of images and corresponding captions using a `tf.data.Dataset` object.
The pipeline consists of two steps:

1. Read the image from the disk
2. Tokenize all the five captions corresponding to the image
"""


def decode_and_resize(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img


def process_input(img_path, captions):
    return decode_and_resize(img_path), vectorization(captions)


def make_dataset(images, captions):
    dataset = tf.data.Dataset.from_tensor_slices((images, captions))
    dataset = dataset.shuffle(BATCH_SIZE * 8)
    dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    return dataset


# Pass the list of images and the list of corresponding captions
train_dataset = make_dataset(list(train_data.keys()), list(train_data.values()))

valid_dataset = make_dataset(list(valid_data.keys()), list(valid_data.values()))
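
"""
The following shape check is an addition to the original example: it pulls a single
batch from the pipeline so you can verify what the model will receive. Images should
arrive as `(BATCH_SIZE, 299, 299, 3)` float tensors and captions as
`(BATCH_SIZE, 5, SEQ_LENGTH)` integer tensors (five tokenized captions per image).
"""

# Inspect one batch of (images, captions) produced by the pipeline
sample_imgs, sample_caps = next(iter(train_dataset))
print("Image batch shape: ", sample_imgs.shape)
print("Caption batch shape: ", sample_caps.shape)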

"""
## Building the model

Our image captioning architecture consists of three models:

1. A CNN: used to extract the image features
2. A TransformerEncoder: The extracted image features are then passed to a Transformer
based encoder that generates a new representation of the inputs
3. A TransformerDecoder: This model takes the encoder output and the text data
(sequences) as inputs and tries to learn to generate the caption.
"""


def get_cnn_model():
    base_model = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3),
        include_top=False,
        weights="imagenet",
    )
    # We freeze our feature extractor
    base_model.trainable = False
    base_model_out = base_model.output
    base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
    cnn_model = keras.models.Model(base_model.input, base_model_out)
    return cnn_model


class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training, mask=None):
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)


class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM,
            sequence_length=SEQ_LENGTH,
            vocab_size=VOCAB_SIZE,
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [
                tf.expand_dims(batch_size, -1),
                tf.constant([1, 1], dtype=tf.int32),
            ],
            axis=0,
        )
        return tf.tile(mask, mult)
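
"""
To build intuition for `get_causal_attention_mask` (the toy snippet below is an
addition to the original example), here is the same lower-triangular construction
for a sequence of length 5: position `i` may only attend to positions `j <= i`,
which keeps the decoder from peeking at future tokens.
"""

# Toy causal mask for a sequence of length 5 (1 = attend, 0 = masked out)
toy_len = 5
toy_i = tf.range(toy_len)[:, tf.newaxis]
toy_j = tf.range(toy_len)
print(tf.cast(toy_i >= toy_j, dtype="int32").numpy())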

class ImageCaptioningModel(keras.Model):
    def __init__(
        self,
        cnn_model,
        encoder,
        decoder,
        num_captions_per_image=5,
        image_aug=None,
    ):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    @property
    def metrics(self):
        # We need to list our metrics here so the `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]
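
"""
`calculate_loss` above averages the per-token cross-entropy only over real tokens:
the padding mask zeroes out padded positions before normalizing by the number of
unmasked tokens. The tiny numeric illustration below is an addition to the original
example and uses made-up per-token loss values.
"""

toy_token_loss = tf.constant([2.0, 1.0, 0.5, 0.0, 0.0])  # hypothetical per-token losses
toy_mask = tf.constant([1.0, 1.0, 1.0, 0.0, 0.0])  # 1 = real token, 0 = padding
# Masked mean: (2.0 + 1.0 + 0.5) / 3 ~= 1.167; padded positions do not contribute
print(float(tf.reduce_sum(toy_token_loss * toy_mask) / tf.reduce_sum(toy_mask)))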

cnn_model = get_cnn_model()
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)

"""
## Model training
"""


# Define the loss function
cross_entropy = keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=None,
)

# EarlyStopping criteria
early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)


# Learning Rate Scheduler for the optimizer
class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, post_warmup_learning_rate, warmup_steps):
        super().__init__()
        self.post_warmup_learning_rate = post_warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        global_step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        warmup_progress = global_step / warmup_steps
        warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
        return tf.cond(
            global_step < warmup_steps,
            lambda: warmup_learning_rate,
            lambda: self.post_warmup_learning_rate,
        )


# Create a learning rate schedule
num_train_steps = len(train_dataset) * EPOCHS
num_warmup_steps = num_train_steps // 15
lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)
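
"""
A quick look at the schedule's behavior (this printout is an addition to the original
example and uses a made-up warmup length of 100 steps): the learning rate ramps up
linearly from 0 to `post_warmup_learning_rate` during warmup and then stays constant.
"""

demo_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=100)  # made-up warmup length
for demo_step in [1, 50, 100, 500]:
    print(f"step {demo_step}: lr = {float(demo_schedule(demo_step)):.2e}")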

# Compile the model
caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy)

# Fit the model
caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)

"""
## Check sample predictions
"""

vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = SEQ_LENGTH - 1
valid_images = list(valid_data.keys())


def generate_caption():
    # Select a random image from the validation dataset
    sample_img = np.random.choice(valid_images)

    # Read the image from the disk
    sample_img = decode_and_resize(sample_img)
    img = sample_img.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(img)
    plt.show()

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = caption_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)


# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()

"""
## End Notes

We saw that the model starts to generate reasonable captions after a few epochs. To keep
this example easily runnable, we have trained it with a few constraints, like a minimal
number of attention heads. To improve the predictions, you can try changing these training
settings and find a good model for your use case.
"""