Path: blob/master/examples/vision/near_dup_search.py
"""1Title: Near-duplicate image search2Author: [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/09/104Last modified: 2023/08/305Description: Building a near-duplicate image search utility using deep learning and locality-sensitive hashing.6Accelerator: GPU7"""89"""10## Introduction1112Fetching similar images in (near) real time is an important use case of information13retrieval systems. Some popular products utilizing it include Pinterest, Google Image14Search, etc. In this example, we will build a similar image search utility using15[Locality Sensitive Hashing](https://towardsdatascience.com/understanding-locality-sensitive-hashing-49f6d1f6134)16(LSH) and [random projection](https://en.wikipedia.org/wiki/Random_projection) on top17of the image representations computed by a pretrained image classifier.18This kind of search engine is also known19as a _near-duplicate (or near-dup) image detector_.20We will also look into optimizing the inference performance of21our search utility on GPU using [TensorRT](https://developer.nvidia.com/tensorrt).2223There are other examples under [keras.io/examples/vision](https://keras.io/examples/vision)24that are worth checking out in this regard:2526* [Metric learning for image similarity search](https://keras.io/examples/vision/metric_learning)27* [Image similarity estimation using a Siamese Network with a triplet loss](https://keras.io/examples/vision/siamese_network)2829Finally, this example uses the following resource as a reference and as such reuses some30of its code:31[Locality Sensitive Hashing for Similar Item Search](https://towardsdatascience.com/locality-sensitive-hashing-for-music-search-f2f1940ace23).3233_Note that in order to optimize the performance of our parser,34you should have a GPU runtime available._35"""3637"""38## Setup39"""4041"""shell42pip install tensorrt43"""4445"""46## Imports47"""4849import matplotlib.pyplot as plt50import tensorflow as tf51import tensorrt52import numpy as np53import time5455import tensorflow_datasets as tfds5657tfds.disable_progress_bar()5859"""60## Load the dataset and create a training set of 1,000 images6162To keep the run time of the example short, we will be using a subset of 1,000 images from63the `tf_flowers` dataset (available through64[TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/tf_flowers))65to build our vocabulary.66"""6768train_ds, validation_ds = tfds.load(69"tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True70)7172IMAGE_SIZE = 22473NUM_IMAGES = 10007475images = []76labels = []7778for image, label in train_ds.take(NUM_IMAGES):79image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))80images.append(image.numpy())81labels.append(label.numpy())8283images = np.array(images)84labels = np.array(labels)8586"""87## Load a pre-trained model88"""8990"""91In this section, we load an image classification model that was trained on the92`tf_flowers` dataset. 85% of the total images were used to build the training set. 
"""
## Load a pre-trained model

In this section, we load an image classification model that was trained on the
`tf_flowers` dataset. 85% of the total images were used to build the training set. For
more details on the training, refer to
[this notebook](https://github.com/sayakpaul/near-dup-parser/blob/main/bit-supervised-training.ipynb).

The underlying model is a BiT-ResNet (proposed in
[Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370)).
The BiT-ResNet family of models is known to provide excellent transfer performance across
a wide variety of downstream tasks.
"""

"""shell
wget -q https://github.com/sayakpaul/near-dup-parser/releases/download/v0.1.0/flower_model_bit_0.96875.zip
unzip -qq flower_model_bit_0.96875.zip
"""

bit_model = tf.keras.models.load_model("flower_model_bit_0.96875")
bit_model.count_params()

"""
## Create an embedding model

To retrieve similar images given a query image, we first need to generate vector
representations of all the images involved. We do this via an
embedding model that extracts output features from our pretrained classifier and
normalizes the resulting feature vectors.
"""

embedding_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
        tf.keras.layers.Rescaling(scale=1.0 / 255),
        bit_model.layers[1],
        tf.keras.layers.Normalization(mean=0, variance=1),
    ],
    name="embedding_model",
)

embedding_model.summary()

"""
Take note of the normalization layer inside the model. It is used to project the
representation vectors onto the unit sphere.
"""

"""
## Hashing utilities
"""


def hash_func(embedding, random_vectors):
    embedding = np.array(embedding)

    # Random projection.
    bools = np.dot(embedding, random_vectors) > 0
    return [bool2int(bool_vec) for bool_vec in bools]


def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j:
            y += 1 << i
    return y


"""
The shape of the vectors coming out of `embedding_model` is `(2048,)`, which is quite
large considering practical aspects (storage, retrieval performance, etc.). So we need
to reduce the dimensionality of the embedding vectors without losing too much of their
information content. This is where *random projection* comes into the picture.
It is based on the principle that if the distances between a group of points are
_approximately_ preserved when they are mapped to a lower-dimensional space, then we
can safely work in that lower-dimensional space instead.

Inside `hash_func()`, we first reduce the dimensionality of the embedding vectors. Then
we compute bitwise hash values for the images to determine their hash buckets. Images
having the same hash values are likely to go into the same hash bucket. From a deployment
perspective, bitwise hash values are cheaper to store and operate on.
"""
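"""
To build some intuition, here is a small, purely illustrative run of `hash_func()` on a
couple of dummy embeddings (the dimensions simply mirror the defaults used later in this
example): each 2048-d vector is projected onto 8 random hyperplanes, and the resulting
sign pattern is packed by `bool2int()` into a single integer bucket id between 0 and 255.
"""

demo_embeddings = np.random.randn(2, 2048)  # Two fake embedding vectors.
demo_random_vectors = np.random.randn(8, 2048).T  # Shape: (2048, 8).
print(hash_func(demo_embeddings, demo_random_vectors))  # e.g. [173, 46]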
"""
## Query utilities

The `Table` class is responsible for building a single hash table. Each entry in the hash
table is a mapping between the reduced embedding of an image from our dataset and a
unique identifier. Because our dimensionality reduction technique involves randomness, it
can happen that similar images are not mapped to the same hash bucket every time the
process is run. To reduce this effect, we will take results from multiple tables into
consideration -- the number of tables and the reduction dimensionality are the key
hyperparameters here.

Crucially, in real-world applications you wouldn't reimplement locality-sensitive hashing
yourself. Instead, you'd likely use one of the following popular libraries:

* [ScaNN](https://github.com/google-research/google-research/tree/master/scann)
* [Annoy](https://github.com/spotify/annoy)
* [Vald](https://github.com/vdaas/vald)
"""


class Table:
    def __init__(self, hash_size, dim):
        self.table = {}
        self.hash_size = hash_size
        self.random_vectors = np.random.randn(hash_size, dim).T

    def add(self, id, vectors, label):
        # Create a unique identifier.
        entry = {"id_label": str(id) + "_" + str(label)}

        # Compute the hash values.
        hashes = hash_func(vectors, self.random_vectors)

        # Add the hash values to the current table.
        for h in hashes:
            if h in self.table:
                self.table[h].append(entry)
            else:
                self.table[h] = [entry]

    def query(self, vectors):
        # Compute hash values for the query vector.
        hashes = hash_func(vectors, self.random_vectors)
        results = []

        # Loop over the query hashes and determine if they exist in
        # the current table.
        for h in hashes:
            if h in self.table:
                results.extend(self.table[h])
        return results


"""
In the following `LSH` class we will pack the utilities to have multiple hash tables.
"""


class LSH:
    def __init__(self, hash_size, dim, num_tables):
        self.num_tables = num_tables
        self.tables = []
        for i in range(self.num_tables):
            self.tables.append(Table(hash_size, dim))

    def add(self, id, vectors, label):
        for table in self.tables:
            table.add(id, vectors, label)

    def query(self, vectors):
        results = []
        for table in self.tables:
            results.extend(table.query(vectors))
        return results
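"""
As a quick, illustrative check (not part of the pipeline below), we can exercise the
`LSH` class directly with dummy vectors: a query that is a slightly perturbed copy of a
stored vector should collide with it in most of the tables, which is exactly what the
retrieval step later relies on.
"""

demo_lsh = LSH(hash_size=8, dim=2048, num_tables=10)
demo_vector = np.random.randn(1, 2048)
demo_lsh.add(id=0, vectors=demo_vector, label=0)

# Query with a near-duplicate of the stored vector. The number of matches is at
# most `num_tables` (one per table in which the two vectors collide).
noisy_query = demo_vector + 0.01 * np.random.randn(1, 2048)
print("Matches across tables:", len(demo_lsh.query(noisy_query)))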
"""
Now we can encapsulate the logic for building and operating with the master LSH table (a
collection of many tables) inside a class. It has two methods:

* `train()`: Responsible for building the final LSH table.
* `query()`: Computes the number of matches given a query image and also quantifies the
similarity score.
"""


class BuildLSHTable:
    def __init__(
        self,
        prediction_model,
        concrete_function=False,
        hash_size=8,
        dim=2048,
        num_tables=10,
    ):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)

        self.prediction_model = prediction_model
        self.concrete_function = concrete_function

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            # Unpack the data.
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]

            # Compute embeddings and update the LSH tables.
            # More on `self.concrete_function()` later.
            if self.concrete_function:
                features = self.prediction_model(tf.constant(image))[
                    "normalization"
                ].numpy()
            else:
                features = self.prediction_model.predict(image)
            self.lsh.add(id, features, label)

    def query(self, image, verbose=True):
        # Compute the embeddings of the query image and fetch the results.
        if len(image.shape) < 4:
            image = image[None, ...]

        if self.concrete_function:
            features = self.prediction_model(tf.constant(image))[
                "normalization"
            ].numpy()
        else:
            features = self.prediction_model.predict(image)

        results = self.lsh.query(features)
        if verbose:
            print("Matches:", len(results))

        # Calculate Jaccard index to quantify the similarity.
        counts = {}
        for r in results:
            if r["id_label"] in counts:
                counts[r["id_label"]] += 1
            else:
                counts[r["id_label"]] = 1
        for k in counts:
            counts[k] = float(counts[k]) / self.dim
        return counts


"""
## Create LSH tables

With our helper utilities and classes implemented, we can now build our LSH table. Since
we will be benchmarking performance between optimized and unoptimized embedding models, we
will also warm up our GPU to avoid any unfair comparison.
"""


# Utility to warm up the GPU.
def warmup():
    dummy_sample = tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3))
    for _ in range(100):
        _ = embedding_model.predict(dummy_sample)


"""
Now we can first do the GPU warm-up and then proceed to build the master LSH table with
`embedding_model`.
"""

warmup()

training_files = zip(images, labels)
lsh_builder = BuildLSHTable(embedding_model)
lsh_builder.train(training_files)

"""
At the time of writing, the wall time was 54.1 seconds on a Tesla T4 GPU. This timing may
vary based on the GPU you are using.
"""

"""
## Optimize the model with TensorRT

For NVIDIA-based GPUs, the
[TensorRT framework](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html)
can be used to dramatically reduce the inference latency by applying various model
optimization techniques like pruning, constant folding, layer fusion, and so on. Here we
will use the `tf.experimental.tensorrt` module to optimize our embedding model.
"""

# First serialize the embedding model as a SavedModel.
embedding_model.save("embedding_model")

# Initialize the conversion parameters.
params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16", maximum_cached_engines=16
)

# Run the conversion.
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="embedding_model", conversion_params=params
)
converter.convert()
converter.save("tensorrt_embedding_model")

"""
**Notes on the parameters inside of `tf.experimental.tensorrt.ConversionParams()`**:

* `precision_mode` defines the numerical precision of the operations in the
to-be-converted model.
* `maximum_cached_engines` specifies the maximum number of TRT engines that will be
cached to handle dynamic operations (operations with unknown shapes).

To learn more about the other options, refer to the
[official documentation](https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt/ConversionParams).
You can also explore the different quantization options provided by the
`tf.experimental.tensorrt` module.
"""

# Load the converted model.
root = tf.saved_model.load("tensorrt_embedding_model")
trt_model_function = root.signatures["serving_default"]

"""
## Build LSH tables with optimized model
"""

warmup()

training_files = zip(images, labels)
lsh_builder_trt = BuildLSHTable(trt_model_function, concrete_function=True)
lsh_builder_trt.train(training_files)

"""
Notice the difference in the wall time, which is **13.1 seconds**. Earlier, with the
unoptimized model it was **54.1 seconds**.

We can take a closer look into one of the hash tables and get an idea of how they are
represented.
"""

idx = 0
for hash, entry in lsh_builder_trt.lsh.tables[0].table.items():
    if idx == 5:
        break
    if len(entry) < 5:
        print(hash, entry)
    idx += 1
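"""
As a rough, purely illustrative diagnostic, we can also summarize how the entries are
spread across buckets in all of the tables -- a few heavily overloaded buckets would
degrade retrieval quality, since every query hitting them returns many candidates.
"""

bucket_sizes = [
    len(entries)
    for table in lsh_builder_trt.lsh.tables
    for entries in table.table.values()
]
print(
    f"Total buckets: {len(bucket_sizes)}, "
    f"mean bucket size: {np.mean(bucket_sizes):.2f}, "
    f"largest bucket: {np.max(bucket_sizes)}"
)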
Then we will benchmark the query performance of the models427with and without optimization.428"""429430"""431First, we take 100 images from the validation set for testing purposes.432"""433434validation_images = []435validation_labels = []436437for image, label in validation_ds.take(100):438image = tf.image.resize(image, (224, 224))439validation_images.append(image.numpy())440validation_labels.append(label.numpy())441442validation_images = np.array(validation_images)443validation_labels = np.array(validation_labels)444validation_images.shape, validation_labels.shape445446447"""448Now we write our visualization utilities.449"""450451452def plot_images(images, labels):453plt.figure(figsize=(20, 10))454columns = 5455for i, image in enumerate(images):456ax = plt.subplot(len(images) // columns + 1, columns, i + 1)457if i == 0:458ax.set_title("Query Image\n" + "Label: {}".format(labels[i]))459else:460ax.set_title("Similar Image # " + str(i) + "\nLabel: {}".format(labels[i]))461plt.imshow(image.astype("int"))462plt.axis("off")463464465def visualize_lsh(lsh_class):466idx = np.random.choice(len(validation_images))467image = validation_images[idx]468label = validation_labels[idx]469results = lsh_class.query(image)470471candidates = []472labels = []473overlaps = []474475for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):476if idx == 4:477break478image_id, label = r.split("_")[0], r.split("_")[1]479candidates.append(images[int(image_id)])480labels.append(label)481overlaps.append(results[r])482483candidates.insert(0, image)484labels.insert(0, label)485486plot_images(candidates, labels)487488489"""490### Non-TRT model491"""492493for _ in range(5):494visualize_lsh(lsh_builder)495496visualize_lsh(lsh_builder)497498"""499### TRT model500"""501502for _ in range(5):503visualize_lsh(lsh_builder_trt)504505"""506As you may have noticed, there are a couple of incorrect results. This can be mitigated in507a few ways:508509* Better models for generating the initial embeddings especially for noisy samples. We can510use techniques like [ArcFace](https://arxiv.org/abs/1801.07698),511[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362), etc.512that implicitly encourage better learning of representations for retrieval purposes.513* The trade-off between the number of tables and the reduction dimensionality is crucial514and helps set the right recall required for your application.515"""516517"""518## Benchmarking query performance519"""520521522def benchmark(lsh_class):523warmup()524525start_time = time.time()526for _ in range(1000):527image = np.ones((1, 224, 224, 3)).astype("float32")528_ = lsh_class.query(image, verbose=False)529end_time = time.time() - start_time530print(f"Time taken: {end_time:.3f}")531532533benchmark(lsh_builder)534535benchmark(lsh_builder_trt)536537"""538We can immediately notice a stark difference between the query performance of the two539models.540"""541542"""543## Final remarks544545In this example, we explored the TensorRT framework from NVIDIA for optimizing our model.546It's best suited for GPU-based inference servers. 
"""
## Final remarks

In this example, we explored the TensorRT framework from NVIDIA for optimizing our model.
It's best suited for GPU-based inference servers. There are other choices for such
frameworks that cater to different hardware platforms:

* [TensorFlow Lite](https://www.tensorflow.org/lite) for mobile and edge devices.
* [ONNX](https://onnx.ai/) for commodity CPU-based servers.
* [Apache TVM](https://tvm.apache.org/), a compiler for machine learning models covering
various platforms.

Here are a few resources you might want to check out to learn more
about applications based on vector similarity search in general:

* [ANN Benchmarks](http://ann-benchmarks.com/)
* [Accelerating Large-Scale Inference with Anisotropic Vector Quantization (ScaNN)](https://arxiv.org/abs/1908.10396)
* [Spreading vectors for similarity search](https://arxiv.org/abs/1806.03198)
* [Building a real-time embeddings similarity matching system](https://cloud.google.com/architecture/building-real-time-embeddings-similarity-matching-system)
"""