Path: blob/master/examples/vision/metric_learning.py
3507 views
"""
Title: Metric learning for image similarity search
Author: [Mat Kelcey](https://twitter.com/mat_kelcey)
Date created: 2020/06/05
Last modified: 2020/06/09
Description: Example of using similarity metric learning on CIFAR-10 images.
Accelerator: GPU
"""

"""
## Overview

Metric learning aims to train models that can embed inputs into a high-dimensional space
such that "similar" inputs, as defined by the training scheme, are located close to each
other. Once trained, these models can produce embeddings for downstream systems where such
similarity is useful; examples include use as a ranking signal for search or as a
pretrained embedding model for another supervised problem.

For a more detailed overview of metric learning see:

* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
"""

"""
## Setup

Set the Keras backend to TensorFlow.
"""
import os

# The backend must be selected before `keras` is imported below.
os.environ["KERAS_BACKEND"] = "tensorflow"

import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers

"""
## Dataset

For this example we will be using the
[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
"""

from keras.datasets import cifar10


(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixels to [0, 1] and flatten the labels from (n, 1) to (n,).
x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

"""
To get a sense of the dataset we can visualise a grid of 25 random examples.
"""

height_width = 32


def show_collage(examples):
    """Tile a 2D grid of images into a single PIL image.

    Args:
        examples: array indexed as `examples[row, col]`, each entry an
            `(height_width, height_width, 3)` image with float pixel values
            expected in `[0, 1]` (they are multiplied by 255 and cast to uint8).

    Returns:
        A `PIL.Image.Image` of the grid, upscaled 2x, with a 2 pixel
        near-white gap between tiles.
    """
    box_size = height_width + 2  # Leaves a 2 pixel gap after each tile.
    num_rows, num_cols = examples.shape[:2]

    collage = Image.new(
        mode="RGB",
        size=(num_cols * box_size, num_rows * box_size),
        color=(250, 250, 250),
    )
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
            collage.paste(
                Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
            )

    # Double size for visualisation.
    collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))
    return collage


# Show a collage of 5x5 random images.
sample_idxs = np.random.randint(0, len(x_train), size=(5, 5))
examples = x_train[sample_idxs]
show_collage(examples)

"""
Metric learning provides training data not as explicit `(X, y)` pairs but instead uses
multiple instances that are related in the way we want to express similarity. In our
example we will use instances of the same class to represent similarity; a single
training instance will not be one image, but a pair of images of the same class. When
referring to the images in this pair we'll use the common metric learning names of the
`anchor` (a randomly chosen image) and the `positive` (another randomly chosen image of
the same class).

To facilitate this we need to build a form of lookup that maps from classes to the
instances of that class. When generating data for training we will sample from this
lookup.
"""

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):
    class_idx_to_train_idxs[y].append(y_train_idx)

class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):
    class_idx_to_test_idxs[y].append(y_test_idx)

"""
For this example we are using the simplest approach to training; a batch will consist of
`(anchor, positive)` pairs spread across the classes. The goal of learning will be to
move the anchor and positive pairs closer together and further away from other instances
in the batch. In this case the batch size will be dictated by the number of classes; for
CIFAR-10 this is 10.
"""

num_classes = 10


class AnchorPositivePairs(keras.utils.Sequence):
    """Dataset of `(anchor, positive)` training pairs, one pair per class per batch.

    Each batch is a float32 array of shape
    `(2, num_classes, height_width, height_width, 3)`: index 0 holds the anchors and
    index 1 the positives, so `x[:, class_idx]` is two distinct training images of
    class `class_idx`. Pairs are re-drawn at random on every access; the batch index
    is ignored.

    Args:
        num_batches: number of batches reported by `len()`, i.e. steps per epoch.
    """

    def __init__(self, num_batches):
        super().__init__()
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __getitem__(self, _idx):
        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            # `random.sample` draws two distinct positions, so an anchor is never
            # paired with itself, and it fails loudly (instead of looping forever)
            # if a class has fewer than two examples.
            anchor_idx, positive_idx = random.sample(
                class_idx_to_train_idxs[class_idx], 2
            )
            x[0, class_idx] = x_train[anchor_idx]
            x[1, class_idx] = x_train[positive_idx]
        return x


"""
We can visualise a batch in another collage. The top row shows randomly chosen anchors
from the 10 classes, the bottom row shows the corresponding 10 positives.
"""

examples = AnchorPositivePairs(num_batches=1)[0]

show_collage(examples)

"""
## Embedding model

We define a custom model with a `train_step` that first embeds both anchors and positives
and then uses their pairwise dot products as logits for a softmax.
"""


class EmbeddingModel(keras.Model):
    """Model whose `train_step` applies an in-batch softmax loss to anchor/positive pairs.

    Training batches are `(anchors, positives)` stacks as produced by
    `AnchorPositivePairs`: row `i` of the anchors and row `i` of the positives share
    a class, and every other positive in the batch acts as a negative for anchor `i`.
    """

    def train_step(self, data):
        # Note: Workaround for open issue, to be removed.
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between anchors and positives. As they have
            # been normalised this is just the pairwise dot products.
            similarities = keras.ops.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1] since
            # we want the main diagonal values, which correspond to the
            # anchor/positive pairs, to be high. This loss will move embeddings for
            # the anchor/positive pairs together and move all other pairs apart.
            sparse_labels = keras.ops.arange(num_classes)
            loss = self.compute_loss(y=sparse_labels, y_pred=similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        for metric in self.metrics:
            # Calling `self.compile` will by default add a `keras.metrics.Mean` loss
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(sparse_labels, similarities)

        return {m.name: m.result() for m in self.metrics}


"""
Next we describe the architecture that maps from an image to an embedding. This model
simply consists of a sequence of 2d convolutions followed by global pooling with a final
linear projection to an embedding space. As is common in metric learning we normalise the
embeddings so that we can use simple dot products to measure similarity. For simplicity
this model is intentionally small.
"""

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)

model = EmbeddingModel(inputs, embeddings)

"""
Finally we run the training. On a Google Colab GPU instance this takes about a minute.
"""
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)

plt.plot(history.history["loss"])
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

"""
## Testing

We can review the quality of this model by applying it to the test set and considering
near neighbours in the embedding space.

First we embed the test set and calculate all near neighbours. Recall that since the
embeddings are unit length we can calculate cosine similarity via dot products.
"""

near_neighbours_per_example = 10

embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
# Exclude each example from its own neighbour list. Relying on self-similarity being
# the single largest entry in its row is fragile: duplicate images or float rounding
# can tie or exceed it, which would keep the example itself and drop a real neighbour.
np.fill_diagonal(gram_matrix, -np.inf)
# `argsort` is ascending, so the last columns of each row are the most similar
# examples; every row lists its neighbours from least to most similar.
near_neighbours = np.argsort(gram_matrix)[:, -near_neighbours_per_example:]

"""
As a visual check of these embeddings we can build a collage of the near neighbours for 5
random examples. The first column of the image below is a randomly selected image, the
following 10 columns show its nearest neighbours, most similar first.
"""

num_collage_examples = 5

examples = np.empty(
    (
        num_collage_examples,
        near_neighbours_per_example + 1,
        height_width,
        height_width,
        3,
    ),
    dtype=np.float32,
)
# Distinct random test images act as the query shown in the first column of each row.
query_idxs = np.random.choice(len(x_test), size=num_collage_examples, replace=False)
for row_idx, query_idx in enumerate(query_idxs):
    examples[row_idx, 0] = x_test[query_idx]
    # Reverse the ascending neighbour list so the most similar image comes first.
    anchor_near_neighbours = reversed(near_neighbours[query_idx])
    for col_idx, nn_idx in enumerate(anchor_near_neighbours):
        examples[row_idx, col_idx + 1] = x_test[nn_idx]

show_collage(examples)

"""
We can also get a quantified view of the performance by considering the correctness of
near neighbours in terms of a confusion matrix.

Let us take the first 10 test examples of each of the 10 classes and consider their near
neighbours as a form of prediction; that is, do the example and its near neighbours share
the same class?

We observe that each animal class does generally well, and is confused the most with the
other animal classes. The vehicle classes follow the same pattern.
"""

confusion_matrix = np.zeros((num_classes, num_classes))
examples_per_class = 10

# For each class.
for class_idx in range(num_classes):
    # Consider the first `examples_per_class` test examples of that class.
    example_idxs = class_idx_to_test_idxs[class_idx][:examples_per_class]
    for y_test_idx in example_idxs:
        # And count the classes of its near neighbours (self is already excluded).
        for nn_idx in near_neighbours[y_test_idx]:
            nn_class_idx = y_test[nn_idx]
            confusion_matrix[class_idx, nn_class_idx] += 1

# Display a confusion matrix.
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()