"""1Title: Metric learning for image similarity search using TensorFlow Similarity2Author: [Owen Vallis](https://twitter.com/owenvallis)3Date created: 2021/09/304Last modified: 2022/02/295Description: Example of using similarity metric learning on CIFAR-10 images.6Accelerator: GPU7"""89"""10## Overview1112This example is based on the13["Metric learning for image similarity search" example](https://keras.io/examples/vision/metric_learning/).14We aim to use the same data set but implement the model using15[TensorFlow Similarity](https://github.com/tensorflow/similarity).1617Metric learning aims to train models that can embed inputs into a18high-dimensional space such that "similar" inputs are pulled closer to each19other and "dissimilar" inputs are pushed farther apart. Once trained, these20models can produce embeddings for downstream systems where such similarity is21useful, for instance as a ranking signal for search or as a form of pretrained22embedding model for another supervised problem.2324For a more detailed overview of metric learning, see:2526* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)27* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)28"""2930"""31## Setup3233This tutorial will use the [TensorFlow Similarity](https://github.com/tensorflow/similarity) library34to learn and evaluate the similarity embedding.35TensorFlow Similarity provides components that:3637* Make training contrastive models simple and fast.38* Make it easier to ensure that batches contain pairs of examples.39* Enable the evaluation of the quality of the embedding.4041TensorFlow Similarity can be installed easily via pip, as follows:4243```44pip -q install tensorflow_similarity45```4647"""4849import random5051from matplotlib import pyplot as plt52from mpl_toolkits import axes_grid153import numpy as np5455import tensorflow as tf56from tensorflow import keras5758import tensorflow_similarity as tfsim596061tfsim.utils.tf_cap_memory()6263print("TensorFlow:", tf.__version__)64print("TensorFlow Similarity:", tfsim.__version__)6566"""67## Dataset samplers6869We will be using the70[CIFAR-10](https://www.tensorflow.org/datasets/catalog/cifar10)71dataset for this tutorial.7273For a similarity model to learn efficiently, each batch must contain at least 274examples of each class.7576To make this easy, tf_similarity offers `Sampler` objects that enable you to set both77the number of classes and the minimum number of examples of each class per78batch.7980The training and validation datasets will be created using the81`TFDatasetMultiShotMemorySampler` object. This creates a sampler that loads datasets82from [TensorFlow Datasets](https://www.tensorflow.org/datasets) and yields83batches containing a target number of classes and a target number of examples84per class. Additionally, we can restrict the sampler to only yield the subset of85classes defined in `class_list`, enabling us to train on a subset of the classes86and then test how the embedding generalizes to the unseen classes. 

The following cell creates a `train_ds` sampler that:

* Loads the CIFAR-10 dataset from TFDS and then takes the `examples_per_class_per_batch`.
* Ensures the sampler restricts the classes to those defined in `class_list`.
* Ensures each batch contains 10 different classes with 8 examples each.

We also create a validation dataset in the same way, but we limit the total number of
examples per class to 100, and the examples per class per batch is set to the
default of 2.
"""

# This determines the number of classes used during training.
# Here we are using all the classes.
num_known_classes = 10
class_list = random.sample(population=range(10), k=num_known_classes)

classes_per_batch = 10
# Passing multiple examples per class per batch ensures that each example has
# multiple positive pairs. This can be useful when performing triplet mining or
# when using losses like `MultiSimilarityLoss` or `CircleLoss`, as these can
# take a weighted mix of all the positive pairs. In general, more examples per
# class will lead to more information for the positive pairs, while more classes
# per batch will provide more varied information in the negative pairs. However,
# the losses compute the pairwise distance between the examples in a batch, so
# the upper limit of the batch size is restricted by the memory.
examples_per_class_per_batch = 8

print(
    "Batch size is: "
    f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
)

print(" Create Training Data ".center(34, "#"))
train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=min(classes_per_batch, num_known_classes),
    splits="train",
    steps_per_epoch=4000,
    examples_per_class_per_batch=examples_per_class_per_batch,
    class_list=class_list,
)

print("\n" + " Create Validation Data ".center(34, "#"))
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=classes_per_batch,
    splits="test",
    total_examples_per_class=100,
)

"""
## Visualize the dataset

The samplers will shuffle the dataset, so we can get a sense of the dataset by
plotting the first 25 images.

The samplers provide a `get_slice(begin, size)` method that allows us to easily
select a block of samples.

Alternatively, we can use the `generate_batch()` method to yield a batch. This
can allow us to check that a batch contains the expected number of classes and
examples per class, as sketched after the plotting code below.
"""

num_cols = num_rows = 5
# Get the first 25 examples.
x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)

fig = plt.figure(figsize=(6.0, 6.0))
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_cols, num_rows), axes_pad=0.1)

for ax, im, label in zip(grid, x_slice, y_slice):
    ax.imshow(im)
    ax.axis("off")
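
# A hedged sketch of the batch check mentioned above. The exact
# `generate_batch()` signature may vary between TensorFlow Similarity versions,
# so this is left commented out rather than run as part of the example.
# batch_x, batch_y = train_ds.generate_batch(0)
# class_ids, counts = np.unique(np.asarray(batch_y), return_counts=True)
# print(f"{len(class_ids)} classes in batch; examples per class: {counts}")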

"""
## Embedding model

Next we define a `SimilarityModel` using the Keras Functional API. The model
is a standard convnet with the addition of a `MetricEmbedding` layer that
applies L2 normalization. The metric embedding layer is helpful when using
`Cosine` distance, as we only care about the angle between the vectors.

Additionally, the `SimilarityModel` provides a number of helper methods for:

* Indexing embedded examples
* Performing example lookups
* Evaluating the classification
* Evaluating the quality of the embedding space

See the [TensorFlow Similarity documentation](https://github.com/tensorflow/similarity)
for more details.
"""

embedding_size = 256

inputs = keras.layers.Input((32, 32, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.GlobalMaxPool2D()(x)
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)

# Build the model.
model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()

"""
## Similarity loss

The similarity loss expects batches containing at least 2 examples of each
class, from which it computes the loss over the pairwise positive and negative
distances. Here we are using `MultiSimilarityLoss()`
([paper](https://arxiv.org/abs/1904.06627)), one of several losses in
[TensorFlow Similarity](https://github.com/tensorflow/similarity). This loss
attempts to use all informative pairs in the batch, taking into account the
self-similarity, positive-similarity, and the negative-similarity.
"""

epochs = 3
learning_rate = 0.002
val_steps = 50

# Initialize the similarity loss.
loss = tfsim.losses.MultiSimilarityLoss()

# Compile and train the model.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    steps_per_execution=10,
)
history = model.fit(
    train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
)

"""
## Indexing

Now that we have trained our model, we can create an index of examples. Here we
batch index the first 200 validation examples by passing the x and y values to
the index, storing the images in the `data` parameter. The `x_index` values are
embedded and then added to the index to make them searchable. The `y_index` and
`data` parameters are optional, but allow the user to associate metadata with the
embedded examples.
"""

x_index, y_index = val_ds.get_slice(begin=0, size=200)
model.reset_index()
model.index(x_index, y_index, data=x_index)
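
# A quick, hedged sanity check on the freshly built index: look up the nearest
# neighbors of a single query. The `label` and `distance` fields on the lookup
# results are assumptions based on the TensorFlow Similarity docs, so this is
# left commented out as a sketch.
# query_x, query_y = val_ds.get_slice(begin=200, size=1)
# neighbors = model.lookup(query_x, k=3)
# for neighbor in neighbors[0]:
#     print("label:", neighbor.label, "distance:", neighbor.distance)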

"""
## Calibration

Once the index is built, we can calibrate a distance threshold using a matching
strategy and a calibration metric.

Here we are searching for the optimal F1 score while using K=1 as our
classifier. All matches at or below the calibrated threshold distance will be
labeled as a Positive match between the query example and the label associated
with the match result, while all matches above the threshold distance will be
labeled as a Negative match.

Additionally, we pass in extra metrics to compute. All values in the
output are computed at the calibrated threshold.

Finally, `model.calibrate()` returns a `CalibrationResults` object containing:

* `"cutpoints"`: A Python dict mapping the cutpoint name to a dict containing the
`ClassificationMetric` values associated with a particular distance threshold,
e.g., `"optimal": {"acc": 0.90, "f1": 0.92}`.
* `"thresholds"`: A Python dict mapping `ClassificationMetric` names to a list
containing the metric's value computed at each of the distance thresholds, e.g.,
`{"f1": [0.99, 0.80], "distance": [0.0, 1.0]}`.
"""

x_train, y_train = train_ds.get_slice(begin=0, size=1000)
calibration = model.calibrate(
    x_train,
    y_train,
    calibration_metric="f1",
    matcher="match_nearest",
    extra_metrics=["precision", "recall", "binary_accuracy"],
    verbose=1,
)
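
# A hedged sketch of inspecting the calibrated cutpoint described above. The
# exact keys inside each cutpoint dict are assumptions, so we only print the
# whole dict rather than index into it; left commented out.
# print("Optimal cutpoint:", calibration.cutpoints["optimal"])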

"""
## Visualization

It may be difficult to get a sense of the model quality from the metrics alone.
A complementary approach is to manually inspect a set of query results to get a
feel for the match quality.

Here we take 10 validation examples and plot them with their 5 nearest
neighbors and the distances to the query example. Looking at the results, we see
that while they are imperfect they still represent meaningfully similar images,
and that the model is able to find similar images irrespective of their pose or
image illumination.

We can also see that the model is very confident with certain images, resulting
in very small distances between the query and the neighbors. Conversely, we see
more mistakes in the class labels as the distances become larger. This is one of
the reasons why calibration is critical for matching applications.
"""

num_neighbors = 5
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
    "Unknown",
]
class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}

x_display, y_display = val_ds.get_slice(begin=200, size=10)
# Look up the nearest neighbors in the index.
nns = model.lookup(x_display, k=num_neighbors)

# Display each query alongside its neighbors.
for idx in np.argsort(y_display):
    tfsim.visualization.viz_neigbors_imgs(
        x_display[idx],
        y_display[idx],
        nns[idx],
        class_mapping=class_mapping,
        fig_size=(16, 2),
    )

"""
## Metrics

We can also plot the extra metrics contained in the `CalibrationResults` to get
a sense of the matching performance as the distance threshold increases.

The following plots show the Precision, Recall, and F1 Score. We can see that
the matching precision degrades as the distance increases, but that the
percentage of the queries that we accept as positive matches (recall) grows
faster up to the calibrated distance threshold.
"""

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
x = calibration.thresholds["distance"]

ax1.plot(x, calibration.thresholds["precision"], label="precision")
ax1.plot(x, calibration.thresholds["recall"], label="recall")
ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
ax1.legend()
ax1.set_title("Metric evolution as distance increases")
ax1.set_xlabel("Distance")
ax1.set_ylim((-0.05, 1.05))

ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax2.set_title("Precision recall curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.set_ylim((-0.05, 1.05))
plt.show()

"""
We can also take 100 examples for each class and plot the confusion matrix for
each example and their nearest match. We also add an "extra" class, labeled 10
("Unknown"), to represent queries whose matches fall above the calibrated
distance threshold.

We can see that most of the errors are between the animal classes, with an
interesting number of confusions between Airplane and Bird. Additionally, we see
that only a few of the 100 examples for each class returned matches outside of
the calibrated distance threshold.
"""

cutpoint = "optimal"

# This yields 100 examples for each class.
# We defined this when we created the val_ds sampler.
x_confusion, y_confusion = val_ds.get_slice(0, -1)

matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
cm = tfsim.visualization.confusion_matrix(
    matches,
    y_confusion,
    labels=labels,
    title="Confusion matrix for cutpoint:%s" % cutpoint,
    normalize=False,
)

"""
## No Match

We can plot the examples outside of the calibrated threshold to see which images
are not matching any indexed examples.

This may provide insight into what other examples may need to be indexed or
surface anomalous examples within the class.
"""

idx_no_match = np.where(np.array(matches) == 10)
no_match_queries = x_confusion[idx_no_match]
if len(no_match_queries):
    plt.imshow(no_match_queries[0])
else:
    print("All queries have a match below the distance threshold.")

"""
## Visualize clusters

One of the best ways to quickly get a sense of how well the model is doing, and
to understand its shortcomings, is to project the embeddings into a 2D space.

This allows us to inspect clusters of images and understand which classes are
entangled.
"""

# Each class in val_ds was restricted to 100 examples.
num_examples_to_clusters = 1000
thumb_size = 96
plot_size = 800
vx, vy = val_ds.get_slice(0, num_examples_to_clusters)

# Uncomment to run the interactive projector.
# tfsim.visualization.projector(
#     model.predict(vx),
#     labels=vy,
#     images=vx,
#     class_mapping=class_mapping,
#     image_size=thumb_size,
#     plot_size=plot_size,
# )
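
# If the interactive projector is not available in your environment, a rough
# static 2D view can be produced with scikit-learn's TSNE instead. This is a
# sketch under the assumption that scikit-learn is installed; it is not a
# dependency of this example, so it is left commented out.
# from sklearn.manifold import TSNE
#
# embeddings = model.predict(vx)
# xy = TSNE(n_components=2).fit_transform(embeddings)
# plt.scatter(xy[:, 0], xy[:, 1], c=vy, s=4, cmap="tab10")
# plt.show()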