Path: blob/master/examples/vision/nl_image_search.py
3507 views
"""1Title: Natural language image search with a Dual Encoder2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/01/304Last modified: 2021/01/305Description: Implementation of a dual encoder model for retrieving images that match natural language queries.6Accelerator: GPU7"""89"""10## Introduction1112The example demonstrates how to build a dual encoder (also known as two-tower) neural network13model to search for images using natural language. The model is inspired by14the [CLIP](https://openai.com/blog/clip/)15approach, introduced by Alec Radford et al. The idea is to train a vision encoder and a text16encoder jointly to project the representation of images and their captions into the same embedding17space, such that the caption embeddings are located near the embeddings of the images they describe.1819This example requires TensorFlow 2.4 or higher.20In addition, [TensorFlow Hub](https://www.tensorflow.org/hub)21and [TensorFlow Text](https://www.tensorflow.org/tutorials/tensorflow_text/intro)22are required for the BERT model, and [TensorFlow Addons](https://www.tensorflow.org/addons)23is required for the AdamW optimizer. These libraries can be installed using the24following command:2526```python27pip install -q -U tensorflow-hub tensorflow-text tensorflow-addons28```29"""3031"""32## Setup33"""3435import os36import collections37import json38import numpy as np39import tensorflow as tf40from tensorflow import keras41from tensorflow.keras import layers42import tensorflow_hub as hub43import tensorflow_text as text44import tensorflow_addons as tfa45import matplotlib.pyplot as plt46import matplotlib.image as mpimg47from tqdm import tqdm4849# Suppressing tf.hub warnings50tf.get_logger().setLevel("ERROR")5152"""53## Prepare the data5455We will use the [MS-COCO](https://cocodataset.org/#home) dataset to train our56dual encoder model. MS-COCO contains over 82,000 images, each of which has at least575 different caption annotations. The dataset is usually used for58[image captioning](https://www.tensorflow.org/tutorials/text/image_captioning)59tasks, but we can repurpose the image-caption pairs to train our dual encoder60model for image search.6162###63Download and extract the data6465First, let's download the dataset, which consists of two compressed folders:66one with images, and the other—with associated image captions.67Note that the compressed images folder is 13GB in size.68"""6970root_dir = "datasets"71annotations_dir = os.path.join(root_dir, "annotations")72images_dir = os.path.join(root_dir, "train2014")73tfrecords_dir = os.path.join(root_dir, "tfrecords")74annotation_file = os.path.join(annotations_dir, "captions_train2014.json")7576# Download caption annotation files77if not os.path.exists(annotations_dir):78annotation_zip = tf.keras.utils.get_file(79"captions.zip",80cache_dir=os.path.abspath("."),81origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",82extract=True,83)84os.remove(annotation_zip)8586# Download image files87if not os.path.exists(images_dir):88image_zip = tf.keras.utils.get_file(89"train2014.zip",90cache_dir=os.path.abspath("."),91origin="http://images.cocodataset.org/zips/train2014.zip",92extract=True,93)94os.remove(image_zip)9596print("Dataset is downloaded and extracted successfully.")9798with open(annotation_file, "r") as f:99annotations = json.load(f)["annotations"]100101image_path_to_caption = collections.defaultdict(list)102for element in annotations:103caption = f"{element['caption'].lower().rstrip('.')}"104image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])105image_path_to_caption[image_path].append(caption)106107image_paths = list(image_path_to_caption.keys())108print(f"Number of images: {len(image_paths)}")109110"""111### Process and save the data to TFRecord files112113You can change the `sample_size` parameter to control many image-caption pairs114will be used for training the dual encoder model.115In this example we set `train_size` to 30,000 images,116which is about 35% of the dataset. We use 2 captions for each117image, thus producing 60,000 image-caption pairs. The size of the training set118affects the quality of the produced encoders, but more examples would lead to119longer training time.120"""121122train_size = 30000123valid_size = 5000124captions_per_image = 2125images_per_file = 2000126127train_image_paths = image_paths[:train_size]128num_train_files = int(np.ceil(train_size / images_per_file))129train_files_prefix = os.path.join(tfrecords_dir, "train")130131valid_image_paths = image_paths[-valid_size:]132num_valid_files = int(np.ceil(valid_size / images_per_file))133valid_files_prefix = os.path.join(tfrecords_dir, "valid")134135tf.io.gfile.makedirs(tfrecords_dir)136137138def bytes_feature(value):139return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))140141142def create_example(image_path, caption):143feature = {144"caption": bytes_feature(caption.encode()),145"raw_image": bytes_feature(tf.io.read_file(image_path).numpy()),146}147return tf.train.Example(features=tf.train.Features(feature=feature))148149150def write_tfrecords(file_name, image_paths):151caption_list = []152image_path_list = []153for image_path in image_paths:154captions = image_path_to_caption[image_path][:captions_per_image]155caption_list.extend(captions)156image_path_list.extend([image_path] * len(captions))157158with tf.io.TFRecordWriter(file_name) as writer:159for example_idx in range(len(image_path_list)):160example = create_example(161image_path_list[example_idx], caption_list[example_idx]162)163writer.write(example.SerializeToString())164return example_idx + 1165166167def write_data(image_paths, num_files, files_prefix):168example_counter = 0169for file_idx in tqdm(range(num_files)):170file_name = files_prefix + "-%02d.tfrecord" % (file_idx)171start_idx = images_per_file * file_idx172end_idx = start_idx + images_per_file173example_counter += write_tfrecords(file_name, image_paths[start_idx:end_idx])174return example_counter175176177train_example_count = write_data(train_image_paths, num_train_files, train_files_prefix)178print(f"{train_example_count} training examples were written to tfrecord files.")179180valid_example_count = write_data(valid_image_paths, num_valid_files, valid_files_prefix)181print(f"{valid_example_count} evaluation examples were written to tfrecord files.")182183"""184### Create `tf.data.Dataset` for training and evaluation185"""186187188feature_description = {189"caption": tf.io.FixedLenFeature([], tf.string),190"raw_image": tf.io.FixedLenFeature([], tf.string),191}192193194def read_example(example):195features = tf.io.parse_single_example(example, feature_description)196raw_image = features.pop("raw_image")197features["image"] = tf.image.resize(198tf.image.decode_jpeg(raw_image, channels=3), size=(299, 299)199)200return features201202203def get_dataset(file_pattern, batch_size):204return (205tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))206.map(207read_example,208num_parallel_calls=tf.data.AUTOTUNE,209deterministic=False,210)211.shuffle(batch_size * 10)212.prefetch(buffer_size=tf.data.AUTOTUNE)213.batch(batch_size)214)215216217"""218## Implement the projection head219220The projection head is used to transform the image and the text embeddings to221the same embedding space with the same dimensionality.222"""223224225def project_embeddings(226embeddings, num_projection_layers, projection_dims, dropout_rate227):228projected_embeddings = layers.Dense(units=projection_dims)(embeddings)229for _ in range(num_projection_layers):230x = tf.nn.gelu(projected_embeddings)231x = layers.Dense(projection_dims)(x)232x = layers.Dropout(dropout_rate)(x)233x = layers.Add()([projected_embeddings, x])234projected_embeddings = layers.LayerNormalization()(x)235return projected_embeddings236237238"""239## Implement the vision encoder240241In this example, we use [Xception](https://keras.io/api/applications/xception/)242from [Keras Applications](https://keras.io/api/applications/) as the base for the243vision encoder.244"""245246247def create_vision_encoder(248num_projection_layers, projection_dims, dropout_rate, trainable=False249):250# Load the pre-trained Xception model to be used as the base encoder.251xception = keras.applications.Xception(252include_top=False, weights="imagenet", pooling="avg"253)254# Set the trainability of the base encoder.255for layer in xception.layers:256layer.trainable = trainable257# Receive the images as inputs.258inputs = layers.Input(shape=(299, 299, 3), name="image_input")259# Preprocess the input image.260xception_input = tf.keras.applications.xception.preprocess_input(inputs)261# Generate the embeddings for the images using the xception model.262embeddings = xception(xception_input)263# Project the embeddings produced by the model.264outputs = project_embeddings(265embeddings, num_projection_layers, projection_dims, dropout_rate266)267# Create the vision encoder model.268return keras.Model(inputs, outputs, name="vision_encoder")269270271"""272## Implement the text encoder273274We use [BERT](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1)275from [TensorFlow Hub](https://tfhub.dev) as the text encoder276"""277278279def create_text_encoder(280num_projection_layers, projection_dims, dropout_rate, trainable=False281):282# Load the BERT preprocessing module.283preprocess = hub.KerasLayer(284"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",285name="text_preprocessing",286)287# Load the pre-trained BERT model to be used as the base encoder.288bert = hub.KerasLayer(289"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",290"bert",291)292# Set the trainability of the base encoder.293bert.trainable = trainable294# Receive the text as inputs.295inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")296# Preprocess the text.297bert_inputs = preprocess(inputs)298# Generate embeddings for the preprocessed text using the BERT model.299embeddings = bert(bert_inputs)["pooled_output"]300# Project the embeddings produced by the model.301outputs = project_embeddings(302embeddings, num_projection_layers, projection_dims, dropout_rate303)304# Create the text encoder model.305return keras.Model(inputs, outputs, name="text_encoder")306307308"""309## Implement the dual encoder310311To calculate the loss, we compute the pairwise dot-product similarity between312each `caption_i` and `images_j` in the batch as the predictions.313The target similarity between `caption_i` and `image_j` is computed as314the average of the (dot-product similarity between `caption_i` and `caption_j`)315and (the dot-product similarity between `image_i` and `image_j`).316Then, we use crossentropy to compute the loss between the targets and the predictions.317"""318319320class DualEncoder(keras.Model):321def __init__(self, text_encoder, image_encoder, temperature=1.0, **kwargs):322super().__init__(**kwargs)323self.text_encoder = text_encoder324self.image_encoder = image_encoder325self.temperature = temperature326self.loss_tracker = keras.metrics.Mean(name="loss")327328@property329def metrics(self):330return [self.loss_tracker]331332def call(self, features, training=False):333# Place each encoder on a separate GPU (if available).334# TF will fallback on available devices if there are fewer than 2 GPUs.335with tf.device("/gpu:0"):336# Get the embeddings for the captions.337caption_embeddings = text_encoder(features["caption"], training=training)338with tf.device("/gpu:1"):339# Get the embeddings for the images.340image_embeddings = vision_encoder(features["image"], training=training)341return caption_embeddings, image_embeddings342343def compute_loss(self, caption_embeddings, image_embeddings):344# logits[i][j] is the dot_similarity(caption_i, image_j).345logits = (346tf.matmul(caption_embeddings, image_embeddings, transpose_b=True)347/ self.temperature348)349# images_similarity[i][j] is the dot_similarity(image_i, image_j).350images_similarity = tf.matmul(351image_embeddings, image_embeddings, transpose_b=True352)353# captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).354captions_similarity = tf.matmul(355caption_embeddings, caption_embeddings, transpose_b=True356)357# targets[i][j] = avarage dot_similarity(caption_i, caption_j) and dot_similarity(image_i, image_j).358targets = keras.activations.softmax(359(captions_similarity + images_similarity) / (2 * self.temperature)360)361# Compute the loss for the captions using crossentropy362captions_loss = keras.losses.categorical_crossentropy(363y_true=targets, y_pred=logits, from_logits=True364)365# Compute the loss for the images using crossentropy366images_loss = keras.losses.categorical_crossentropy(367y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True368)369# Return the mean of the loss over the batch.370return (captions_loss + images_loss) / 2371372def train_step(self, features):373with tf.GradientTape() as tape:374# Forward pass375caption_embeddings, image_embeddings = self(features, training=True)376loss = self.compute_loss(caption_embeddings, image_embeddings)377# Backward pass378gradients = tape.gradient(loss, self.trainable_variables)379self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))380# Monitor loss381self.loss_tracker.update_state(loss)382return {"loss": self.loss_tracker.result()}383384def test_step(self, features):385caption_embeddings, image_embeddings = self(features, training=False)386loss = self.compute_loss(caption_embeddings, image_embeddings)387self.loss_tracker.update_state(loss)388return {"loss": self.loss_tracker.result()}389390391"""392## Train the dual encoder model393394In this experiment, we freeze the base encoders for text and images, and make only395the projection head trainable.396"""397398num_epochs = 5 # In practice, train for at least 30 epochs399batch_size = 256400401vision_encoder = create_vision_encoder(402num_projection_layers=1, projection_dims=256, dropout_rate=0.1403)404text_encoder = create_text_encoder(405num_projection_layers=1, projection_dims=256, dropout_rate=0.1406)407dual_encoder = DualEncoder(text_encoder, vision_encoder, temperature=0.05)408dual_encoder.compile(409optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)410)411412"""413Note that training the model with 60,000 image-caption pairs, with a batch size of 256,414takes around 12 minutes per epoch using a V100 GPU accelerator. If 2 GPUs are available,415the epoch takes around 8 minutes.416"""417418print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")419print(f"Number of examples (caption-image pairs): {train_example_count}")420print(f"Batch size: {batch_size}")421print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")422train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)423valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)424# Create a learning rate scheduler callback.425reduce_lr = keras.callbacks.ReduceLROnPlateau(426monitor="val_loss", factor=0.2, patience=3427)428# Create an early stopping callback.429early_stopping = tf.keras.callbacks.EarlyStopping(430monitor="val_loss", patience=5, restore_best_weights=True431)432history = dual_encoder.fit(433train_dataset,434epochs=num_epochs,435validation_data=valid_dataset,436callbacks=[reduce_lr, early_stopping],437)438print("Training completed. Saving vision and text encoders...")439vision_encoder.save("vision_encoder")440text_encoder.save("text_encoder")441print("Models are saved.")442443"""444Plotting the training loss:445"""446447plt.plot(history.history["loss"])448plt.plot(history.history["val_loss"])449plt.ylabel("Loss")450plt.xlabel("Epoch")451plt.legend(["train", "valid"], loc="upper right")452plt.show()453454"""455## Search for images using natural language queries456457We can then retrieve images corresponding to natural language queries via458the following steps:4594601. Generate embeddings for the images by feeding them into the `vision_encoder`.4612. Feed the natural language query to the `text_encoder` to generate a query embedding.4623. Compute the similarity between the query embedding and the image embeddings463in the index to retrieve the indices of the top matches.4644. Look up the paths of the top matching images to display them.465466Note that, after training the `dual encoder`, only the fine-tuned `vision_encoder`467and `text_encoder` models will be used, while the `dual_encoder` model will be discarded.468"""469470"""471### Generate embeddings for the images472473We load the images and feed them into the `vision_encoder` to generate their embeddings.474In large scale systems, this step is performed using a parallel data processing framework,475such as [Apache Spark](https://spark.apache.org) or [Apache Beam](https://beam.apache.org).476Generating the image embeddings may take several minutes.477"""478print("Loading vision and text encoders...")479vision_encoder = keras.models.load_model("vision_encoder")480text_encoder = keras.models.load_model("text_encoder")481print("Models are loaded.")482483484def read_image(image_path):485image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)486return tf.image.resize(image_array, (299, 299))487488489print(f"Generating embeddings for {len(image_paths)} images...")490image_embeddings = vision_encoder.predict(491tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),492verbose=1,493)494print(f"Image embeddings shape: {image_embeddings.shape}.")495496"""497### Retrieve relevant images498499In this example, we use exact matching by computing the dot product similarity500between the input query embedding and the image embeddings, and retrieve the top k501matches. However, *approximate* similarity matching, using frameworks like502[ScaNN](https://github.com/google-research/google-research/tree/master/scann),503[Annoy](https://github.com/spotify/annoy), or [Faiss](https://github.com/facebookresearch/faiss)504is preferred in real-time use cases to scale with a large number of images.505"""506507508def find_matches(image_embeddings, queries, k=9, normalize=True):509# Get the embedding for the query.510query_embedding = text_encoder(tf.convert_to_tensor(queries))511# Normalize the query and the image embeddings.512if normalize:513image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)514query_embedding = tf.math.l2_normalize(query_embedding, axis=1)515# Compute the dot product between the query and the image embeddings.516dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)517# Retrieve top k indices.518results = tf.math.top_k(dot_similarity, k).indices.numpy()519# Return matching image paths.520return [[image_paths[idx] for idx in indices] for indices in results]521522523"""524Set the `query` variable to the type of images you want to search for.525Try things like: 'a plate of healthy food',526'a woman wearing a hat is walking down a sidewalk',527'a bird sits near to the water', or 'wild animals are standing in a field'.528"""529530query = "a family standing next to the ocean on a sandy beach with a surf board"531matches = find_matches(image_embeddings, [query], normalize=True)[0]532533plt.figure(figsize=(20, 20))534for i in range(9):535ax = plt.subplot(3, 3, i + 1)536plt.imshow(mpimg.imread(matches[i]))537plt.axis("off")538539540"""541## Evaluate the retrieval quality542543To evaluate the dual encoder model, we use the captions as queries.544We use the out-of-training-sample images and captions to evaluate the retrieval quality,545using top k accuracy. A true prediction is counted if, for a given caption, its associated image546is retrieved within the top k matches.547"""548549550def compute_top_k_accuracy(image_paths, k=100):551hits = 0552num_batches = int(np.ceil(len(image_paths) / batch_size))553for idx in tqdm(range(num_batches)):554start_idx = idx * batch_size555end_idx = start_idx + batch_size556current_image_paths = image_paths[start_idx:end_idx]557queries = [558image_path_to_caption[image_path][0] for image_path in current_image_paths559]560result = find_matches(image_embeddings, queries, k)561hits += sum(562[563image_path in matches564for (image_path, matches) in list(zip(current_image_paths, result))565]566)567568return hits / len(image_paths)569570571print("Scoring training data...")572train_accuracy = compute_top_k_accuracy(train_image_paths)573print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")574575print("Scoring evaluation data...")576eval_accuracy = compute_top_k_accuracy(image_paths[train_size:])577print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")578579580"""581## Final remarks582583You can obtain better results by increasing the size of the training sample,584train for more epochs, explore other base encoders for images and text,585set the base encoders to be trainable, and tune the hyperparameters,586especially the `temperature` for the softmax in the loss computation.587588Example available on HuggingFace589590| Trained Model | Demo |591| :--: | :--: |592| [](https://huggingface.co/keras-io/dual-encoder-image-search) | [](https://huggingface.co/spaces/keras-io/dual-encoder-image-search) |593"""594595596