"""
2
Title: Metric learning for image similarity search
3
Author: [Mat Kelcey](https://twitter.com/mat_kelcey)
4
Date created: 2020/06/05
5
Last modified: 2020/06/09
6
Description: Example of using similarity metric learning on CIFAR-10 images.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Overview
12
13
Metric learning aims to train models that can embed inputs into a high-dimensional space
14
such that "similar" inputs, as defined by the training scheme, are located close to each
15
other. These models once trained can produce embeddings for downstream systems where such
16
similarity is useful; examples include as a ranking signal for search or as a form of
17
pretrained embedding model for another supervised problem.
18
19
For a more detailed overview of metric learning see:
20
21
* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
22
* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
23
"""
24
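
"""
To make "close together" concrete: once embeddings are unit length, similarity is just
a dot product. As a toy sketch (with made-up 2-D embeddings, not from any trained
model), the first two vectors below point in similar directions and so score higher
against each other than against the third:
"""

import numpy as np  # numpy is imported again in the Setup section below

toy_embeddings = np.array([[1.0, 0.0], [0.96, 0.28], [0.0, 1.0]])  # all unit length
print(toy_embeddings @ toy_embeddings[0])  # [1.   0.96 0.  ]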

"""
## Setup

Set the Keras backend to TensorFlow.
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import random
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from collections import defaultdict
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
import keras
from keras import layers
"""
45
## Dataset
46
47
For this example we will be using the
48
[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
49
"""
50
51
from keras.datasets import cifar10
52
53
54
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
55
56
x_train = x_train.astype("float32") / 255.0
57
y_train = np.squeeze(y_train)
58
x_test = x_test.astype("float32") / 255.0
59
y_test = np.squeeze(y_test)
60
61
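
"""
A quick shape check: after the `np.squeeze` calls, the labels are flat vectors of
class indices.
"""

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000,)
print(x_test.shape, y_test.shape)  # (10000, 32, 32, 3) (10000,)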
"""
62
To get a sense of the dataset we can visualise a grid of 25 random examples.
63
64
65
"""
66
67
height_width = 32
68
69
70
def show_collage(examples):
71
box_size = height_width + 2
72
num_rows, num_cols = examples.shape[:2]
73
74
collage = Image.new(
75
mode="RGB",
76
size=(num_cols * box_size, num_rows * box_size),
77
color=(250, 250, 250),
78
)
79
for row_idx in range(num_rows):
80
for col_idx in range(num_cols):
81
array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
82
collage.paste(
83
Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
84
)
85
86
# Double size for visualisation.
87
collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))
88
return collage
89
90
91
# Show a collage of 5x5 random images.
92
sample_idxs = np.random.randint(0, 50000, size=(5, 5))
93
examples = x_train[sample_idxs]
94
show_collage(examples)
95
96
"""
97
Metric learning provides training data not as explicit `(X, y)` pairs but instead uses
98
multiple instances that are related in the way we want to express similarity. In our
99
example we will use instances of the same class to represent similarity; a single
100
training instance will not be one image, but a pair of images of the same class. When
101
referring to the images in this pair we'll use the common metric learning names of the
102
`anchor` (a randomly chosen image) and the `positive` (another randomly chosen image of
103
the same class).
104
105
To facilitate this we need to build a form of lookup that maps from classes to the
106
instances of that class. When generating data for training we will sample from this
107
lookup.
108
"""

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):
    class_idx_to_train_idxs[y].append(y_train_idx)

class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):
    class_idx_to_test_idxs[y].append(y_test_idx)
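
"""
As a quick check of this lookup: CIFAR-10 is balanced, with 5,000 training and 1,000
test examples per class, so each class's list should have exactly those lengths.
"""

print(len(class_idx_to_train_idxs[0]))  # 5000
print(len(class_idx_to_test_idxs[0]))  # 1000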
"""
119
For this example we are using the simplest approach to training; a batch will consist of
120
`(anchor, positive)` pairs spread across the classes. The goal of learning will be to
121
move the anchor and positive pairs closer together and further away from other instances
122
in the batch. In this case the batch size will be dictated by the number of classes; for
123
CIFAR-10 this is 10.
124
"""
125

num_classes = 10


class AnchorPositivePairs(keras.utils.Sequence):
    def __init__(self, num_batches):
        super().__init__()
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __getitem__(self, _idx):
        # x[0] holds one anchor per class, x[1] the corresponding positives.
        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            examples_for_class = class_idx_to_train_idxs[class_idx]
            anchor_idx = random.choice(examples_for_class)
            positive_idx = random.choice(examples_for_class)
            # Resample until the positive is a different image than the anchor.
            while positive_idx == anchor_idx:
                positive_idx = random.choice(examples_for_class)
            x[0, class_idx] = x_train[anchor_idx]
            x[1, class_idx] = x_train[positive_idx]
        return x


"""
We can visualise a batch in another collage. The top row shows randomly chosen anchors
from the 10 classes, the bottom row shows the corresponding 10 positives.
"""

examples = next(iter(AnchorPositivePairs(num_batches=1)))

show_collage(examples)
"""
160
## Embedding model
161
162
We define a custom model with a `train_step` that first embeds both anchors and positives
163
and then uses their pairwise dot products as logits for a softmax.
164
"""
165
166


class EmbeddingModel(keras.Model):
    def train_step(self, data):
        # Note: Workaround for open issue, to be removed.
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through the model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between anchors and positives. As they have
            # been normalised this is just the pairwise dot products.
            similarities = keras.ops.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence [0, 1, 2, ..., num_classes - 1] since
            # we want the main diagonal values, which correspond to the
            # anchor/positive pairs, to be high. This loss will move embeddings for
            # the anchor/positive pairs together and move all other pairs apart.
            sparse_labels = keras.ops.arange(num_classes)
            loss = self.compute_loss(y=sparse_labels, y_pred=similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        for metric in self.metrics:
            # Calling `self.compile` will by default add a `keras.metrics.Mean` loss.
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(sparse_labels, similarities)

        return {m.name: m.result() for m in self.metrics}
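
"""
To see why the labels are simply `[0, 1, ..., num_classes - 1]`: row `i` of
`similarities` holds anchor `i`'s dot products against every positive, and the
matching positive sits at column `i`, i.e. on the main diagonal. A toy check of this
indexing, using a made-up similarity matrix whose diagonal is dominant:
"""

toy_similarities = keras.ops.eye(3) * 5.0  # pretend each anchor best matches its positive
print(keras.ops.argmax(toy_similarities, axis=1))  # [0 1 2], matching arange(3) labels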
"""
214
Next we describe the architecture that maps from an image to an embedding. This model
215
simply consists of a sequence of 2d convolutions followed by global pooling with a final
216
linear projection to an embedding space. As is common in metric learning we normalise the
217
embeddings so that we can use simple dot products to measure similarity. For simplicity
218
this model is intentionally small.
219
"""
220

inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=8, activation=None)(x)
embeddings = layers.UnitNormalization()(embeddings)

model = EmbeddingModel(inputs, embeddings)
"""
232
Finally we run the training. On a Google Colab GPU instance this takes about a minute.
233
"""
234
model.compile(
235
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
236
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
237
)
238
239
history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)
240
241
plt.plot(history.history["loss"])
242
plt.show()
243
244
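
"""
As a quick sanity check before evaluating, the trained model should produce
unit-length embeddings, since its final layer is a `UnitNormalization`; the norms
printed below should all be ~1.0.
"""

sample_embeddings = model.predict(x_test[:3], verbose=0)
print(sample_embeddings.shape)  # (3, 8)
print(np.linalg.norm(sample_embeddings, axis=1))  # ~[1. 1. 1.]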
"""
245
## Testing
246
247
We can review the quality of this model by applying it to the test set and considering
248
near neighbours in the embedding space.
249
250
First we embed the test set and calculate all near neighbours. Recall that since the
251
embeddings are unit length we can calculate cosine similarity via dot products.
252
"""
253

near_neighbours_per_example = 10

embeddings = model.predict(x_test)
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
# argsort sorts ascending, so the last columns hold the most similar examples;
# we keep one extra column because each example's nearest neighbour is itself.
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
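
"""
A toy illustration of that slicing pattern, with a hypothetical one-row similarity
matrix: the self-similarity of 1.0 sorts last, preceded by the nearest neighbours.
"""

toy_sims = np.array([[0.1, 0.9, 0.4, 1.0]])
print(np.argsort(toy_sims)[:, -3:])  # [[2 1 3]]: two nearest neighbours, then self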
"""
261
As a visual check of these embeddings we can build a collage of the near neighbours for 5
262
random examples. The first column of the image below is a randomly selected image, the
263
following 10 columns show the nearest neighbours in order of similarity.
264
"""
265

num_collage_examples = 5

examples = np.empty(
    (
        num_collage_examples,
        near_neighbours_per_example + 1,
        height_width,
        height_width,
        3,
    ),
    dtype=np.float32,
)
for row_idx in range(num_collage_examples):
    examples[row_idx, 0] = x_test[row_idx]
    anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])
    for col_idx, nn_idx in enumerate(anchor_near_neighbours):
        examples[row_idx, col_idx + 1] = x_test[nn_idx]

show_collage(examples)

"""
We can also get a quantified view of the performance by considering the correctness of
near neighbours in terms of a confusion matrix.

Let us sample 10 examples from each of the 10 classes and consider their near
neighbours as a form of prediction; that is, do the example and its near neighbours
share the same class?

We observe that each animal class generally does well, and is confused the most with
the other animal classes. The vehicle classes follow the same pattern.
"""

confusion_matrix = np.zeros((num_classes, num_classes))

# For each class.
for class_idx in range(num_classes):
    # Consider 10 examples.
    example_idxs = class_idx_to_test_idxs[class_idx][:10]
    for y_test_idx in example_idxs:
        # And count the classes of its near neighbours.
        for nn_idx in near_neighbours[y_test_idx][:-1]:
            nn_class_idx = y_test[nn_idx]
            confusion_matrix[class_idx, nn_class_idx] += 1

# Display a confusion matrix.
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
plt.show()