"""
2
Title: Near-duplicate image search
3
Author: [Sayak Paul](https://twitter.com/RisingSayak)
4
Date created: 2021/09/10
5
Last modified: 2023/08/30
6
Description: Building a near-duplicate image search utility using deep learning and locality-sensitive hashing.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
Fetching similar images in (near) real time is an important use case of information
14
retrieval systems. Some popular products utilizing it include Pinterest, Google Image
15
Search, etc. In this example, we will build a similar image search utility using
16
[Locality Sensitive Hashing](https://towardsdatascience.com/understanding-locality-sensitive-hashing-49f6d1f6134)
17
(LSH) and [random projection](https://en.wikipedia.org/wiki/Random_projection) on top
18
of the image representations computed by a pretrained image classifier.
19
This kind of search engine is also known
20
as a _near-duplicate (or near-dup) image detector_.
21
We will also look into optimizing the inference performance of
22
our search utility on GPU using [TensorRT](https://developer.nvidia.com/tensorrt).
23
24
There are other examples under [keras.io/examples/vision](https://keras.io/examples/vision)
25
that are worth checking out in this regard:
26
27
* [Metric learning for image similarity search](https://keras.io/examples/vision/metric_learning)
28
* [Image similarity estimation using a Siamese Network with a triplet loss](https://keras.io/examples/vision/siamese_network)
29
30
Finally, this example uses the following resource as a reference and as such reuses some
31
of its code:
32
[Locality Sensitive Hashing for Similar Item Search](https://towardsdatascience.com/locality-sensitive-hashing-for-music-search-f2f1940ace23).
33
34
_Note that in order to optimize the performance of our parser,
35
you should have a GPU runtime available._
36
"""
37
38
"""
39
## Setup
40
"""
41
42
"""shell
43
pip install tensorrt
44
"""
45
46
"""
47
## Imports
48
"""
49
50
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorrt
import numpy as np
import time

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

"""
61
## Load the dataset and create a training set of 1,000 images
62
63
To keep the run time of the example short, we will be using a subset of 1,000 images from
64
the `tf_flowers` dataset (available through
65
[TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/tf_flowers))
66
to build our vocabulary.
67
"""
68
69
train_ds, validation_ds = tfds.load(
    "tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True
)

IMAGE_SIZE = 224
NUM_IMAGES = 1000

images = []
labels = []

for image, label in train_ds.take(NUM_IMAGES):
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    images.append(image.numpy())
    labels.append(label.numpy())

images = np.array(images)
labels = np.array(labels)

"""
88
## Load a pre-trained model
89
"""
90
91
"""
92
In this section, we load an image classification model that was trained on the
93
`tf_flowers` dataset. 85% of the total images were used to build the training set. For
94
more details on the training, refer to
95
[this notebook](https://github.com/sayakpaul/near-dup-parser/blob/main/bit-supervised-training.ipynb).
96
97
The underlying model is a BiT-ResNet (proposed in
98
[Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370)).
99
The BiT-ResNet family of models is known to provide excellent transfer performance across
100
a wide variety of different downstream tasks.
101
"""
102
103
"""shell
104
wget -q https://github.com/sayakpaul/near-dup-parser/releases/download/v0.1.0/flower_model_bit_0.96875.zip
105
unzip -qq flower_model_bit_0.96875.zip
106
"""
107
108
bit_model = tf.keras.models.load_model("flower_model_bit_0.96875")
109
bit_model.count_params()
110
111
"""
112
## Create an embedding model
113
114
To retrieve similar images given a query image, we need to first generate vector
115
representations of all the images involved. We do this via an
116
embedding model that extracts output features from our pretrained classifier and
117
normalizes the resulting feature vectors.
118
"""
119
120
embedding_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
        tf.keras.layers.Rescaling(scale=1.0 / 255),
        bit_model.layers[1],
        tf.keras.layers.Normalization(mean=0, variance=1),
    ],
    name="embedding_model",
)

embedding_model.summary()

"""
133
Take note of the normalization layer inside the model. It is used to project the
134
representation vectors to the space of unit-spheres.
135
"""
136
137
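"""
As a quick, optional sanity check, we can embed a single image and confirm that the
output is a 2048-dimensional vector (we rely on this dimensionality below):
"""

sample_embedding = embedding_model.predict(images[:1])
print(sample_embedding.shape)  # Expected: (1, 2048).
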
"""
138
## Hashing utilities
139
"""
140
141
142
def hash_func(embedding, random_vectors):
    embedding = np.array(embedding)

    # Random projection.
    bools = np.dot(embedding, random_vectors) > 0
    return [bool2int(bool_vec) for bool_vec in bools]


def bool2int(x):
    # Pack a boolean vector into a single integer (bit i of the result
    # corresponds to entry i of the vector).
    y = 0
    for i, j in enumerate(x):
        if j:
            y += 1 << i
    return y


"""
159
The shape of the vectors coming out of `embedding_model` is `(2048,)`, and considering practical
160
aspects (storage, retrieval performance, etc.) it is quite large. So, there arises a need
161
to reduce the dimensionality of the embedding vectors without reducing their information
162
content. This is where *random projection* comes into the picture.
163
It is based on the principle that if the
164
distance between a group of points on a given plane is _approximately_ preserved, the
165
dimensionality of that plane can further be reduced.
166
167
Inside `hash_func()`, we first reduce the dimensionality of the embedding vectors. Then
168
we compute the bitwise hash values of the images to determine their hash buckets. Images
169
having same hash values are likely to go into the same hash bucket. From a deployment
170
perspective, bitwise hash values are cheaper to store and operate on.
171
"""
172
173
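"""
As a quick, illustrative example (with throwaway random values, not used by the rest of
the pipeline), hashing a single random 2048-dimensional vector against eight random
hyperplanes yields one integer between 0 and 255, i.e. an 8-bit bucket id:
"""

demo_embedding = np.random.randn(1, 2048)
demo_random_vectors = np.random.randn(8, 2048).T
print(hash_func(demo_embedding, demo_random_vectors))  # A list with one integer in [0, 255].
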
"""
174
## Query utilities
175
176
The `Table` class is responsible for building a single hash table. Each entry in the hash
177
table is a mapping between the reduced embedding of an image from our dataset and a
178
unique identifier. Because our dimensionality reduction technique involves randomness, it
179
can so happen that similar images are not mapped to the same hash bucket everytime the
180
process run. To reduce this effect, we will take results from multiple tables into
181
consideration -- the number of tables and the reduction dimensionality are the key
182
hyperparameters here.
183
184
Crucially, you wouldn't reimplement locality-sensitive hashing yourself when working with
185
real world applications. Instead, you'd likely use one of the following popular libraries:
186
187
* [ScaNN](https://github.com/google-research/google-research/tree/master/scann)
188
* [Annoy](https://github.com/spotify/annoy)
189
* [Vald](https://github.com/vdaas/vald)
190
"""
191
192
193
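"""
For reference, here is a minimal sketch of what indexing and querying our embeddings could
look like with Annoy. This is an assumption on our part -- Annoy is not used anywhere else
in this example and would need to be installed separately (`pip install annoy`) -- so the
snippet is kept commented out:
"""

# from annoy import AnnoyIndex
#
# annoy_index = AnnoyIndex(2048, "angular")  # 2048-d embeddings, cosine-like metric.
# for i, vector in enumerate(all_embeddings):  # `all_embeddings`: hypothetical (N, 2048) array.
#     annoy_index.add_item(i, vector)
# annoy_index.build(10)  # Build 10 trees.
# print(annoy_index.get_nns_by_vector(query_embedding, 5))  # Five nearest-neighbor ids.

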
class Table:
    def __init__(self, hash_size, dim):
        self.table = {}
        self.hash_size = hash_size
        self.random_vectors = np.random.randn(hash_size, dim).T

    def add(self, id, vectors, label):
        # Create a unique identifier.
        entry = {"id_label": str(id) + "_" + str(label)}

        # Compute the hash values.
        hashes = hash_func(vectors, self.random_vectors)

        # Add the hash values to the current table.
        for h in hashes:
            if h in self.table:
                self.table[h].append(entry)
            else:
                self.table[h] = [entry]

    def query(self, vectors):
        # Compute hash values for the query vector.
        hashes = hash_func(vectors, self.random_vectors)
        results = []

        # Loop over the query hashes and determine if they exist in
        # the current table.
        for h in hashes:
            if h in self.table:
                results.extend(self.table[h])
        return results


"""
227
In the following `LSH` class we will pack the utilities to have multiple hash tables.
228
"""
229
230
231
class LSH:
    def __init__(self, hash_size, dim, num_tables):
        self.num_tables = num_tables
        self.tables = []
        for i in range(self.num_tables):
            self.tables.append(Table(hash_size, dim))

    def add(self, id, vectors, label):
        for table in self.tables:
            table.add(id, vectors, label)

    def query(self, vectors):
        results = []
        for table in self.tables:
            results.extend(table.query(vectors))
        return results


"""
250
Now we can encapsulate the logic for building and operating with the master LSH table (a
251
collection of many tables) inside a class. It has two methods:
252
253
* `train()`: Responsible for building the final LSH table.
254
* `query()`: Computes the number of matches given a query image and also quantifies the
255
similarity score.
256
"""
257
258
259
class BuildLSHTable:
    def __init__(
        self,
        prediction_model,
        concrete_function=False,
        hash_size=8,
        dim=2048,
        num_tables=10,
    ):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

        self.prediction_model = prediction_model
        self.concrete_function = concrete_function

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            # Unpack the data.
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]

            # Compute embeddings and update the LSH tables.
            # More on `self.concrete_function()` later.
            if self.concrete_function:
                features = self.prediction_model(tf.constant(image))[
                    "normalization"
                ].numpy()
            else:
                features = self.prediction_model.predict(image)
            self.lsh.add(id, features, label)

    def query(self, image, verbose=True):
        # Compute the embeddings of the query image and fetch the results.
        if len(image.shape) < 4:
            image = image[None, ...]

        if self.concrete_function:
            features = self.prediction_model(tf.constant(image))[
                "normalization"
            ].numpy()
        else:
            features = self.prediction_model.predict(image)

        results = self.lsh.query(features)
        if verbose:
            print("Matches:", len(results))

        # Calculate Jaccard index to quantify the similarity.
        counts = {}
        for r in results:
            if r["id_label"] in counts:
                counts[r["id_label"]] += 1
            else:
                counts[r["id_label"]] = 1
        for k in counts:
            counts[k] = float(counts[k]) / self.dim
        return counts


"""
322
## Create LSH tables
323
324
With our helper utilities and classes implemented, we can now build our LSH table. Since
325
we will be benchmarking performance between optimized and unoptimized embedding models, we
326
will also warm up our GPU to avoid any unfair comparison.
327
"""
328
329
330
# Utility to warm up the GPU.
def warmup():
    dummy_sample = tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3))
    for _ in range(100):
        _ = embedding_model.predict(dummy_sample)


"""
338
Now we can first do the GPU wam-up and proceed to build the master LSH table with
339
`embedding_model`.
340
"""
341
342
warmup()

training_files = zip(images, labels)
lsh_builder = BuildLSHTable(embedding_model)
lsh_builder.train(training_files)


"""
350
At the time of writing, the wall time was 54.1 seconds on a Tesla T4 GPU. This timing may
351
vary based on the GPU you are using.
352
"""
353
354
"""
355
## Optimize the model with TensorRT
356
357
For NVIDIA-based GPUs, the
358
[TensorRT framework](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html)
359
can be used to dramatically enhance the inference latency by using various model
360
optimization techniques like pruning, constant folding, layer fusion, and so on. Here we
361
will use the `tf.experimental.tensorrt` module to optimize our embedding model.
362
"""
363
364
# First serialize the embedding model as a SavedModel.
embedding_model.save("embedding_model")

# Initialize the conversion parameters.
params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16", maximum_cached_engines=16
)

# Run the conversion.
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="embedding_model", conversion_params=params
)
converter.convert()
converter.save("tensorrt_embedding_model")

"""
380
**Notes on the parameters inside of `tf.experimental.tensorrt.ConversionParams()`**:
381
382
* `precision_mode` defines the numerical precision of the operations in the
383
to-be-converted model.
384
* `maximum_cached_engines` specifies the maximum number of TRT engines that will be
385
cached to handle dynamic operations (operations with unknown shapes).
386
387
To learn more about the other options, refer to the
388
[official documentation](https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt/ConversionParams).
389
You can also explore the different quantization options provided by the
390
`tf.experimental.tensorrt` module.
391
"""
392
393
# Load the converted model.
root = tf.saved_model.load("tensorrt_embedding_model")
trt_model_function = root.signatures["serving_default"]

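"""
The loaded signature is a concrete function that returns a dictionary of output tensors
keyed by layer name; we index it with `"normalization"` below. As an optional check, you
can print the available output keys:
"""

print(trt_model_function.structured_outputs)
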
"""
398
## Build LSH tables with optimized model
399
"""
400
401
warmup()

training_files = zip(images, labels)
lsh_builder_trt = BuildLSHTable(trt_model_function, concrete_function=True)
lsh_builder_trt.train(training_files)

"""
408
Notice the difference in the wall time which is **13.1 seconds**. Earlier, with the
409
unoptimized model it was **54.1 seconds**.
410
411
We can take a closer look into one of the hash tables and get an idea of how they are
412
represented.
413
"""
414
415
idx = 0
for hash, entry in lsh_builder_trt.lsh.tables[0].table.items():
    if idx == 5:
        break
    if len(entry) < 5:
        print(hash, entry)
        idx += 1

"""
424
## Visualize results on validation images
425
426
In this section we will first writing a couple of utility functions to visualize the
427
similar image parsing process. Then we will benchmark the query performance of the models
428
with and without optimization.
429
"""
430
431
"""
432
First, we take 100 images from the validation set for testing purposes.
433
"""
434
435
validation_images = []
validation_labels = []

for image, label in validation_ds.take(100):
    image = tf.image.resize(image, (224, 224))
    validation_images.append(image.numpy())
    validation_labels.append(label.numpy())

validation_images = np.array(validation_images)
validation_labels = np.array(validation_labels)
validation_images.shape, validation_labels.shape


"""
449
Now we write our visualization utilities.
450
"""
451
452
453
def plot_images(images, labels):
    plt.figure(figsize=(20, 10))
    columns = 5
    for i, image in enumerate(images):
        ax = plt.subplot(len(images) // columns + 1, columns, i + 1)
        if i == 0:
            ax.set_title("Query Image\n" + "Label: {}".format(labels[i]))
        else:
            ax.set_title("Similar Image # " + str(i) + "\nLabel: {}".format(labels[i]))
        plt.imshow(image.astype("int"))
        plt.axis("off")


def visualize_lsh(lsh_class):
    idx = np.random.choice(len(validation_images))
    image = validation_images[idx]
    label = validation_labels[idx]
    results = lsh_class.query(image)

    candidates = []
    labels = []
    overlaps = []

    for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
        if idx == 4:
            break
        image_id, label = r.split("_")[0], r.split("_")[1]
        candidates.append(images[int(image_id)])
        labels.append(label)
        overlaps.append(results[r])

    candidates.insert(0, image)
    labels.insert(0, label)

    plot_images(candidates, labels)


"""
491
### Non-TRT model
492
"""
493
494
for _ in range(5):
    visualize_lsh(lsh_builder)

visualize_lsh(lsh_builder)

"""
500
### TRT model
501
"""
502
503
for _ in range(5):
    visualize_lsh(lsh_builder_trt)

"""
507
As you may have noticed, there are a couple of incorrect results. This can be mitigated in
508
a few ways:
509
510
* Better models for generating the initial embeddings especially for noisy samples. We can
511
use techniques like [ArcFace](https://arxiv.org/abs/1801.07698),
512
[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362), etc.
513
that implicitly encourage better learning of representations for retrieval purposes.
514
* The trade-off between the number of tables and the reduction dimensionality is crucial
515
and helps set the right recall required for your application.
516
"""
517
518
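"""
As a small, hypothetical sketch of how you might experiment with these hyperparameters,
here is an alternative configuration with more hyperplanes per table and more tables.
Constructing the object is cheap; rebuilding the index repeats the full embedding pass,
so that step is left commented out:
"""

lsh_builder_wide = BuildLSHTable(embedding_model, hash_size=16, num_tables=20)
# lsh_builder_wide.train(zip(images, labels))  # Uncomment to rebuild the index (slow).
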
"""
519
## Benchmarking query performance
520
"""
521
522
523
def benchmark(lsh_class):
    warmup()

    start_time = time.time()
    for _ in range(1000):
        image = np.ones((1, 224, 224, 3)).astype("float32")
        _ = lsh_class.query(image, verbose=False)
    end_time = time.time() - start_time
    print(f"Time taken: {end_time:.3f}")


benchmark(lsh_builder)

benchmark(lsh_builder_trt)

"""
539
We can immediately notice a stark difference between the query performance of the two
540
models.
541
"""
542
543
"""
544
## Final remarks
545
546
In this example, we explored the TensorRT framework from NVIDIA for optimizing our model.
547
It's best suited for GPU-based inference servers. There are other choices for such
548
frameworks that cater to different hardware platforms:
549
550
* [TensorFlow Lite](https://www.tensorflow.org/lite) for mobile and edge devices.
551
* [ONNX](hhttps://onnx.ai/) for commodity CPU-based servers.
552
* [Apache TVM](https://tvm.apache.org/), compiler for machine learning models covering
553
various platforms.
554
555
Here are a few resources you might want to check out to learn more
556
about applications based on vector similary search in general:
557
558
* [ANN Benchmarks](http://ann-benchmarks.com/)
559
* [Accelerating Large-Scale Inference with Anisotropic Vector Quantization(ScaNN)](https://arxiv.org/abs/1908.10396)
560
* [Spreading vectors for similarity search](https://arxiv.org/abs/1806.03198)
561
* [Building a real-time embeddings similarity matching system](https://cloud.google.com/architecture/building-real-time-embeddings-similarity-matching-system)
562
"""
563
564