"""
2
Title: Natural language image search with a Dual Encoder
3
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
4
Date created: 2021/01/30
5
Last modified: 2021/01/30
6
Description: Implementation of a dual encoder model for retrieving images that match natural language queries.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
The example demonstrates how to build a dual encoder (also known as two-tower) neural network
14
model to search for images using natural language. The model is inspired by
15
the [CLIP](https://openai.com/blog/clip/)
16
approach, introduced by Alec Radford et al. The idea is to train a vision encoder and a text
17
encoder jointly to project the representation of images and their captions into the same embedding
18
space, such that the caption embeddings are located near the embeddings of the images they describe.
19
20
This example requires TensorFlow 2.4 or higher.
21
In addition, [TensorFlow Hub](https://www.tensorflow.org/hub)
22
and [TensorFlow Text](https://www.tensorflow.org/tutorials/tensorflow_text/intro)
23
are required for the BERT model, and [TensorFlow Addons](https://www.tensorflow.org/addons)
24
is required for the AdamW optimizer. These libraries can be installed using the
25
following command:
26
27
```python
28
pip install -q -U tensorflow-hub tensorflow-text tensorflow-addons
29
```
30
"""
"""
## Setup
"""

import os
import collections
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_hub as hub
import tensorflow_text as text
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm

# Suppressing tf.hub warnings
tf.get_logger().setLevel("ERROR")
"""
## Prepare the data

We will use the [MS-COCO](https://cocodataset.org/#home) dataset to train our
dual encoder model. MS-COCO contains over 82,000 images, each of which has at least
5 different caption annotations. The dataset is usually used for
[image captioning](https://www.tensorflow.org/tutorials/text/image_captioning)
tasks, but we can repurpose the image-caption pairs to train our dual encoder
model for image search.

### Download and extract the data

First, let's download the dataset, which consists of two compressed folders:
one with images, and the other with the associated image captions.
Note that the compressed images folder is 13GB in size.
"""
root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
images_dir = os.path.join(root_dir, "train2014")
tfrecords_dir = os.path.join(root_dir, "tfrecords")
annotation_file = os.path.join(annotations_dir, "captions_train2014.json")

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip = tf.keras.utils.get_file(
        "captions.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
        extract=True,
    )
    os.remove(annotation_zip)

# Download image files
if not os.path.exists(images_dir):
    image_zip = tf.keras.utils.get_file(
        "train2014.zip",
        cache_dir=os.path.abspath("."),
        origin="http://images.cocodataset.org/zips/train2014.zip",
        extract=True,
    )
    os.remove(image_zip)

print("Dataset is downloaded and extracted successfully.")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

image_path_to_caption = collections.defaultdict(list)
for element in annotations:
    caption = element["caption"].lower().rstrip(".")
    image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
    image_path_to_caption[image_path].append(caption)

image_paths = list(image_path_to_caption.keys())
print(f"Number of images: {len(image_paths)}")
"""
### Process and save the data to TFRecord files

You can change the `train_size` parameter to control how many image-caption pairs
will be used for training the dual encoder model.
In this example we set `train_size` to 30,000 images,
which is about 35% of the dataset. We use 2 captions for each
image, thus producing 60,000 image-caption pairs. The size of the training set
affects the quality of the produced encoders, but more examples also lead to
longer training time.
"""
train_size = 30000
valid_size = 5000
captions_per_image = 2
images_per_file = 2000

train_image_paths = image_paths[:train_size]
num_train_files = int(np.ceil(train_size / images_per_file))
train_files_prefix = os.path.join(tfrecords_dir, "train")

valid_image_paths = image_paths[-valid_size:]
num_valid_files = int(np.ceil(valid_size / images_per_file))
valid_files_prefix = os.path.join(tfrecords_dir, "valid")

tf.io.gfile.makedirs(tfrecords_dir)


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def create_example(image_path, caption):
    feature = {
        "caption": bytes_feature(caption.encode()),
        "raw_image": bytes_feature(tf.io.read_file(image_path).numpy()),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
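"""
As a quick illustration (not part of the training pipeline), each serialized
`tf.train.Example` can be parsed back with the same two-feature spec that we
use later in `read_example`; the caption string here is a made-up placeholder:

```python
example = create_example(train_image_paths[0], "a sample caption")
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {
        "caption": tf.io.FixedLenFeature([], tf.string),
        "raw_image": tf.io.FixedLenFeature([], tf.string),
    },
)
print(parsed["caption"].numpy())  # b'a sample caption'
```
"""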
def write_tfrecords(file_name, image_paths):
    caption_list = []
    image_path_list = []
    for image_path in image_paths:
        captions = image_path_to_caption[image_path][:captions_per_image]
        caption_list.extend(captions)
        image_path_list.extend([image_path] * len(captions))

    with tf.io.TFRecordWriter(file_name) as writer:
        for example_idx in range(len(image_path_list)):
            example = create_example(
                image_path_list[example_idx], caption_list[example_idx]
            )
            writer.write(example.SerializeToString())
    return example_idx + 1


def write_data(image_paths, num_files, files_prefix):
    example_counter = 0
    for file_idx in tqdm(range(num_files)):
        file_name = files_prefix + "-%02d.tfrecord" % (file_idx)
        start_idx = images_per_file * file_idx
        end_idx = start_idx + images_per_file
        example_counter += write_tfrecords(file_name, image_paths[start_idx:end_idx])
    return example_counter


train_example_count = write_data(train_image_paths, num_train_files, train_files_prefix)
print(f"{train_example_count} training examples were written to tfrecord files.")

valid_example_count = write_data(valid_image_paths, num_valid_files, valid_files_prefix)
print(f"{valid_example_count} evaluation examples were written to tfrecord files.")
"""
### Create `tf.data.Dataset` for training and evaluation
"""
feature_description = {
    "caption": tf.io.FixedLenFeature([], tf.string),
    "raw_image": tf.io.FixedLenFeature([], tf.string),
}


def read_example(example):
    features = tf.io.parse_single_example(example, feature_description)
    raw_image = features.pop("raw_image")
    features["image"] = tf.image.resize(
        tf.image.decode_jpeg(raw_image, channels=3), size=(299, 299)
    )
    return features


def get_dataset(file_pattern, batch_size):
    return (
        tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))
        .map(
            read_example,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        .shuffle(batch_size * 10)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
        .batch(batch_size)
    )
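"""
As a sanity check (illustration only), you can pull a single batch out of the
pipeline and verify the shapes of the parsed features before training:

```python
sample_batch = next(iter(get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), 4)))
print(sample_batch["caption"].shape)  # (4,)
print(sample_batch["image"].shape)  # (4, 299, 299, 3)
```
"""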
"""
## Implement the projection head

The projection head is used to transform the image and the text embeddings to
the same embedding space with the same dimensionality.
"""
def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = tf.nn.gelu(projected_embeddings)
        x = layers.Dense(projection_dims)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.Add()([projected_embeddings, x])
        projected_embeddings = layers.LayerNormalization()(x)
    return projected_embeddings
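"""
For intuition (illustration only, using random dummy embeddings), the projection
head maps a batch of embeddings of any dimensionality to `projection_dims`-dimensional
vectors, with one residual block per projection layer:

```python
dummy_embeddings = tf.random.normal((8, 2048))  # e.g. Xception pooled features
projected = project_embeddings(
    dummy_embeddings, num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
print(projected.shape)  # (8, 256)
```
"""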
"""
## Implement the vision encoder

In this example, we use [Xception](https://keras.io/api/applications/xception/)
from [Keras Applications](https://keras.io/api/applications/) as the base for the
vision encoder.
"""
def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained Xception model to be used as the base encoder.
    xception = keras.applications.Xception(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in xception.layers:
        layer.trainable = trainable
    # Receive the images as inputs.
    inputs = layers.Input(shape=(299, 299, 3), name="image_input")
    # Preprocess the input image.
    xception_input = tf.keras.applications.xception.preprocess_input(inputs)
    # Generate the embeddings for the images using the Xception model.
    embeddings = xception(xception_input)
    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model(inputs, outputs, name="vision_encoder")
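"""
The resulting model maps a batch of 299x299 RGB images to `projection_dims`-dimensional
embeddings. A minimal sketch of how you might verify this (not executed here, since it
downloads the ImageNet weights):

```python
encoder = create_vision_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
print(encoder.output_shape)  # (None, 256)
```
"""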
"""
## Implement the text encoder

We use [BERT](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1)
from [TensorFlow Hub](https://tfhub.dev) as the text encoder.
"""
def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the BERT preprocessing module.
    preprocess = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",
        name="text_preprocessing",
    )
    # Load the pre-trained BERT model to be used as the base encoder.
    bert = hub.KerasLayer(
        "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
        name="bert",
    )
    # Set the trainability of the base encoder.
    bert.trainable = trainable
    # Receive the text as inputs.
    inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")
    # Preprocess the text.
    bert_inputs = preprocess(inputs)
    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(bert_inputs)["pooled_output"]
    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")
"""
## Implement the dual encoder

To calculate the loss, we compute the pairwise dot-product similarity between
each `caption_i` and `image_j` in the batch as the predictions.
The target similarity between `caption_i` and `image_j` is computed as
the average of the dot-product similarity between `caption_i` and `caption_j`
and the dot-product similarity between `image_i` and `image_j`.
Then, we use crossentropy to compute the loss between the targets and the predictions.
"""
class DualEncoder(keras.Model):
    def __init__(self, text_encoder, image_encoder, temperature=1.0, **kwargs):
        super().__init__(**kwargs)
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.temperature = temperature
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def call(self, features, training=False):
        # Place each encoder on a separate GPU (if available).
        # TF will fall back to available devices if there are fewer than 2 GPUs.
        with tf.device("/gpu:0"):
            # Get the embeddings for the captions.
            caption_embeddings = self.text_encoder(features["caption"], training=training)
        with tf.device("/gpu:1"):
            # Get the embeddings for the images.
            image_embeddings = self.image_encoder(features["image"], training=training)
        return caption_embeddings, image_embeddings

    def compute_loss(self, caption_embeddings, image_embeddings):
        # logits[i][j] is the dot_similarity(caption_i, image_j).
        logits = (
            tf.matmul(caption_embeddings, image_embeddings, transpose_b=True)
            / self.temperature
        )
        # images_similarity[i][j] is the dot_similarity(image_i, image_j).
        images_similarity = tf.matmul(
            image_embeddings, image_embeddings, transpose_b=True
        )
        # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
        captions_similarity = tf.matmul(
            caption_embeddings, caption_embeddings, transpose_b=True
        )
        # targets[i][j] is the average of dot_similarity(caption_i, caption_j)
        # and dot_similarity(image_i, image_j), passed through a softmax.
        targets = keras.activations.softmax(
            (captions_similarity + images_similarity) / (2 * self.temperature)
        )
        # Compute the loss for the captions using crossentropy.
        captions_loss = keras.losses.categorical_crossentropy(
            y_true=targets, y_pred=logits, from_logits=True
        )
        # Compute the loss for the images using crossentropy.
        images_loss = keras.losses.categorical_crossentropy(
            y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True
        )
        # Return the average of the caption and image losses.
        return (captions_loss + images_loss) / 2

    def train_step(self, features):
        with tf.GradientTape() as tape:
            # Forward pass
            caption_embeddings, image_embeddings = self(features, training=True)
            loss = self.compute_loss(caption_embeddings, image_embeddings)
        # Backward pass
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Monitor loss
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, features):
        caption_embeddings, image_embeddings = self(features, training=False)
        loss = self.compute_loss(caption_embeddings, image_embeddings)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
"""
## Train the dual encoder model

In this experiment, we freeze the base encoders for text and images, and make only
the projection head trainable.
"""
num_epochs = 5  # In practice, train for at least 30 epochs
batch_size = 256

vision_encoder = create_vision_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
text_encoder = create_text_encoder(
    num_projection_layers=1, projection_dims=256, dropout_rate=0.1
)
dual_encoder = DualEncoder(text_encoder, vision_encoder, temperature=0.05)
dual_encoder.compile(
    optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)
)
"""
Note that training the model with 60,000 image-caption pairs, with a batch size of 256,
takes around 12 minutes per epoch using a V100 GPU accelerator. If 2 GPUs are available,
each epoch takes around 8 minutes.
"""
print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")
print(f"Number of examples (caption-image pairs): {train_example_count}")
print(f"Batch size: {batch_size}")
print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")
train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)
valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)
# Create a learning rate scheduler callback.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.2, patience=3
)
# Create an early stopping callback.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
history = dual_encoder.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=valid_dataset,
    callbacks=[reduce_lr, early_stopping],
)
print("Training completed. Saving vision and text encoders...")
vision_encoder.save("vision_encoder")
text_encoder.save("text_encoder")
print("Models are saved.")
"""
Plotting the training and validation losses:
"""

plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(["train", "valid"], loc="upper right")
plt.show()
"""
## Search for images using natural language queries

We can then retrieve images corresponding to natural language queries via
the following steps:

1. Generate embeddings for the images by feeding them into the `vision_encoder`.
2. Feed the natural language query to the `text_encoder` to generate a query embedding.
3. Compute the similarity between the query embedding and the image embeddings
in the index to retrieve the indices of the top matches.
4. Look up the paths of the top matching images to display them.

Note that, after training the `dual_encoder`, only the fine-tuned `vision_encoder`
and `text_encoder` models will be used, while the `dual_encoder` model will be discarded.
"""
"""
### Generate embeddings for the images

We load the images and feed them into the `vision_encoder` to generate their embeddings.
In large-scale systems, this step is performed using a parallel data processing framework,
such as [Apache Spark](https://spark.apache.org) or [Apache Beam](https://beam.apache.org).
Generating the image embeddings may take several minutes.
"""
print("Loading vision and text encoders...")
vision_encoder = keras.models.load_model("vision_encoder")
text_encoder = keras.models.load_model("text_encoder")
print("Models are loaded.")


def read_image(image_path):
    image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    return tf.image.resize(image_array, (299, 299))


print(f"Generating embeddings for {len(image_paths)} images...")
image_embeddings = vision_encoder.predict(
    tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),
    verbose=1,
)
print(f"Image embeddings shape: {image_embeddings.shape}.")
"""
### Retrieve relevant images

In this example, we use exact matching by computing the dot product similarity
between the input query embedding and the image embeddings, and retrieve the top k
matches. However, *approximate* similarity matching, using frameworks like
[ScaNN](https://github.com/google-research/google-research/tree/master/scann),
[Annoy](https://github.com/spotify/annoy), or [Faiss](https://github.com/facebookresearch/faiss),
is preferred in real-time use cases to scale with a large number of images.
"""
def find_matches(image_embeddings, queries, k=9, normalize=True):
    # Get the embeddings for the queries.
    query_embedding = text_encoder(tf.convert_to_tensor(queries))
    # Normalize the query and the image embeddings.
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
    # Compute the dot product between the query and the image embeddings.
    dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
    # Retrieve top k indices.
    results = tf.math.top_k(dot_similarity, k).indices.numpy()
    # Return matching image paths.
    return [[image_paths[idx] for idx in indices] for indices in results]
"""
Set the `query` variable to the type of images you want to search for.
Try things like: 'a plate of healthy food',
'a woman wearing a hat is walking down a sidewalk',
'a bird sits near to the water', or 'wild animals are standing in a field'.
"""

query = "a family standing next to the ocean on a sandy beach with a surf board"
matches = find_matches(image_embeddings, [query], normalize=True)[0]

plt.figure(figsize=(20, 20))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(mpimg.imread(matches[i]))
    plt.axis("off")
"""
## Evaluate the retrieval quality

To evaluate the dual encoder model, we use the captions as queries.
We use the out-of-training-sample images and captions to evaluate the retrieval quality,
using top k accuracy. A true prediction is counted if, for a given caption, its associated image
is retrieved within the top k matches.
"""
def compute_top_k_accuracy(image_paths, k=100):
    hits = 0
    num_batches = int(np.ceil(len(image_paths) / batch_size))
    for idx in tqdm(range(num_batches)):
        start_idx = idx * batch_size
        end_idx = start_idx + batch_size
        current_image_paths = image_paths[start_idx:end_idx]
        queries = [
            image_path_to_caption[image_path][0] for image_path in current_image_paths
        ]
        result = find_matches(image_embeddings, queries, k)
        hits += sum(
            [
                image_path in matches
                for (image_path, matches) in list(zip(current_image_paths, result))
            ]
        )

    return hits / len(image_paths)


print("Scoring training data...")
train_accuracy = compute_top_k_accuracy(train_image_paths)
print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")

print("Scoring evaluation data...")
eval_accuracy = compute_top_k_accuracy(image_paths[train_size:])
print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")
"""
## Final remarks

You can obtain better results by increasing the size of the training sample,
training for more epochs, exploring other base encoders for images and text,
setting the base encoders to be trainable, and tuning the hyperparameters,
especially the `temperature` for the softmax in the loss computation.

Example available on HuggingFace

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-nl%20image%20search-black.svg)](https://huggingface.co/keras-io/dual-encoder-image-search) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-nl%20image%20search-black.svg)](https://huggingface.co/spaces/keras-io/dual-encoder-image-search) |
"""