# GitHub repository: keras-team/keras-io
# Path: blob/master/examples/vision/metric_learning_tf_similarity.py
"""
Title: Metric learning for image similarity search using TensorFlow Similarity
Author: [Owen Vallis](https://twitter.com/owenvallis)
Date created: 2021/09/30
Last modified: 2022/02/29
Description: Example of using similarity metric learning on CIFAR-10 images.
Accelerator: GPU
"""

"""
## Overview

This example is based on the
["Metric learning for image similarity search" example](https://keras.io/examples/vision/metric_learning/).
We use the same dataset but implement the model using
[TensorFlow Similarity](https://github.com/tensorflow/similarity).

Metric learning aims to train models that can embed inputs into a
high-dimensional space such that "similar" inputs are pulled closer to each
other and "dissimilar" inputs are pushed farther apart. Once trained, these
models can produce embeddings for downstream systems where such similarity is
useful, for instance as a ranking signal for search or as a pretrained
embedding model for another supervised problem.

For a more detailed overview of metric learning, see:

* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
"""

"""
## Setup

This tutorial uses the [TensorFlow Similarity](https://github.com/tensorflow/similarity) library
to learn and evaluate the similarity embedding.
TensorFlow Similarity provides components that:

* Make training contrastive models simple and fast.
* Make it easier to ensure that batches contain pairs of examples.
* Enable the evaluation of the quality of the embedding.

TensorFlow Similarity can be installed with pip:

```
pip -q install tensorflow_similarity
```

"""

import random

from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
import numpy as np

import tensorflow as tf
from tensorflow import keras

import tensorflow_similarity as tfsim


# Let TensorFlow grow GPU memory usage as needed instead of reserving it all upfront.
tfsim.utils.tf_cap_memory()

print("TensorFlow:", tf.__version__)
print("TensorFlow Similarity:", tfsim.__version__)

"""
## Dataset samplers

We will be using the
[CIFAR-10](https://www.tensorflow.org/datasets/catalog/cifar10)
dataset for this tutorial.

For a similarity model to learn efficiently, each batch must contain at least 2
examples of each class it includes, so that every example has at least one
positive pair.

To make this easy, TensorFlow Similarity offers `Sampler` objects that enable you
to set both the number of classes and the minimum number of examples of each
class per batch.

The training and validation datasets will be created using the
`TFDatasetMultiShotMemorySampler` object. This creates a sampler that loads datasets
from [TensorFlow Datasets](https://www.tensorflow.org/datasets) and yields
batches containing a target number of classes and a target number of examples
per class. Additionally, we can restrict the sampler to only yield the subset of
classes defined in `class_list`, enabling us to train on a subset of the classes
and then test how the embedding generalizes to the unseen classes. This can be
useful when working on few-shot learning problems.

The following cell creates a `train_ds` sampler that:

* Loads the CIFAR-10 training split from TFDS.
* Ensures the sampler restricts the classes to those defined in `class_list`.
* Ensures each batch contains 10 different classes with 8 examples each.

We also create a validation dataset in the same way, but we limit the total number of
examples per class to 100 and leave the examples per class per batch at the
default of 2.

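To make the batch structure concrete, the following NumPy sketch shows one way a multi-shot batch could be assembled from an in-memory array of labels. `sample_multishot_batch` is a hypothetical helper written for illustration. It is not how `TFDatasetMultiShotMemorySampler` is implemented, and it assumes the labels fit in memory as a 1-D array.

```python
import numpy as np


def sample_multishot_batch(y, classes_per_batch, examples_per_class,
                           class_list=None, seed=0):
    # Hypothetical helper: returns indices into `y` that form one batch with
    # `classes_per_batch` distinct classes and `examples_per_class` examples
    # of each, drawn without replacement.
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    # Restrict sampling to `class_list` if given, otherwise use every class.
    candidates = np.unique(y) if class_list is None else np.asarray(class_list)
    chosen = rng.choice(candidates, size=classes_per_batch, replace=False)
    batch = []
    for c in chosen:
        class_idx = np.flatnonzero(y == c)
        batch.extend(rng.choice(class_idx, size=examples_per_class, replace=False))
    return np.array(batch)


# Toy labels: 10 classes with 20 examples each.
toy_labels = np.repeat(np.arange(10), 20)
idx = sample_multishot_batch(toy_labels, classes_per_batch=10, examples_per_class=8)
print(len(idx))  # 80
print(np.bincount(toy_labels[idx]))  # [8 8 8 8 8 8 8 8 8 8]
```

With 10 classes and 8 examples per class, the batch size is 80. This is the same value the training cell prints for this configuration.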
"""
# This determines the number of classes used during training.
# Here we are using all the classes.
num_known_classes = 10
class_list = random.sample(population=range(10), k=num_known_classes)

classes_per_batch = 10
# Passing multiple examples per class per batch ensures that each example has
# multiple positive pairs. This can be useful when performing triplet mining or
# when using losses like `MultiSimilarityLoss` or `CircleLoss` as these can
# take a weighted mix of all the positive pairs. In general, more examples per
# class will lead to more information for the positive pairs, while more classes
# per batch will provide more varied information in the negative pairs. However,
# the losses compute the pairwise distance between the examples in a batch, so
# the batch size is ultimately limited by the available memory.
examples_per_class_per_batch = 8

print(
    "Batch size is: "
    f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
)

print(" Create Training Data ".center(34, "#"))
train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=min(classes_per_batch, num_known_classes),
    splits="train",
    steps_per_epoch=4000,
    examples_per_class_per_batch=examples_per_class_per_batch,
    class_list=class_list,
)

print("\n" + " Create Validation Data ".center(34, "#"))
val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
    "cifar10",
    classes_per_batch=classes_per_batch,
    splits="test",
    total_examples_per_class=100,
)

"""
## Visualize the dataset

The samplers shuffle the dataset, so plotting the first 25 images gives us a
sense of its contents.

The samplers provide a `get_slice(begin, size)` method that allows us to easily
select a block of samples.

Alternatively, we can use the `generate_batch()` method to yield a batch. This
lets us check that a batch contains the expected number of classes and
examples per class.

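A check of that kind might look like the sketch below. `check_batch` is a hypothetical helper that only inspects a 1-D array of labels. We make no assumption about the exact structure `generate_batch()` returns, so you would first need to extract the labels from a real batch yourself.

```python
import numpy as np


def check_batch(y_batch, expected_classes, expected_per_class):
    # Count how many examples of each class appear in the batch.
    classes, counts = np.unique(np.asarray(y_batch), return_counts=True)
    has_enough_classes = len(classes) == expected_classes
    has_enough_examples = bool(np.all(counts == expected_per_class))
    return has_enough_classes and has_enough_examples


# A well-formed toy batch: 10 classes, 8 examples each.
good = np.repeat(np.arange(10), 8)
# A malformed toy batch: class 9 only appears once.
bad = np.concatenate([np.repeat(np.arange(9), 8), [9]])
print(check_batch(good, 10, 8))  # True
print(check_batch(bad, 10, 8))  # False
```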
"""

num_cols = num_rows = 5
# Get the first 25 examples.
x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)

fig = plt.figure(figsize=(6.0, 6.0))
grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_rows, num_cols), axes_pad=0.1)

for ax, im, label in zip(grid, x_slice, y_slice):
    ax.imshow(im)
    ax.axis("off")

"""
## Embedding model

Next we define a `SimilarityModel` using the Keras Functional API. The model
is a standard convnet with the addition of a `MetricEmbedding` layer that
applies L2 normalization. The metric embedding layer is helpful when using
`Cosine` distance, because cosine distance depends only on the angle between
the vectors, not on their length.

Additionally, the `SimilarityModel` provides a number of helper methods for:

* Indexing embedded examples
* Performing example lookups
* Evaluating the classification
* Evaluating the quality of the embedding space

See the [TensorFlow Similarity documentation](https://github.com/tensorflow/similarity)
for more details.

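The following NumPy sketch illustrates why L2 normalization pairs well with cosine distance. Once vectors have unit length, cosine similarity reduces to a dot product, and vector magnitude no longer affects the distance. The helpers are our own illustration, not the `MetricEmbedding` layer or TensorFlow Similarity's distance implementation.

```python
import numpy as np


def l2_normalize(v, eps=1e-12):
    # Scale a vector to unit length; eps guards against division by zero.
    return v / max(np.linalg.norm(v), eps)


def cosine_distance(a, b):
    # For unit-length vectors, cosine similarity is simply the dot product.
    return 1.0 - float(np.dot(l2_normalize(a), l2_normalize(b)))


a = np.array([3.0, 4.0])
b = np.array([6.0, 8.0])  # same direction as `a`, twice the length
c = np.array([-4.0, 3.0])  # orthogonal to `a`
print(cosine_distance(a, b))  # approximately 0.0: magnitude is ignored
print(cosine_distance(a, c))  # 1.0: orthogonal vectors
```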
"""

embedding_size = 256

inputs = keras.layers.Input((32, 32, 3))
# Scale pixel values from [0, 255] to [0, 1].
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
x = keras.layers.Conv2D(64, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(128, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D((4, 4))(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(256, 3, activation="relu")(x)
x = keras.layers.GlobalMaxPool2D()(x)
# Project to `embedding_size` dimensions and L2-normalize the output.
outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)

# Build the model.
model = tfsim.models.SimilarityModel(inputs, outputs)
model.summary()

"""
## Similarity loss

The similarity loss expects batches containing at least 2 examples of each
class, from which it computes the loss over the pairwise positive and negative
distances. Here we are using `MultiSimilarityLoss()`
([paper](https://arxiv.org/abs/1904.06627)), one of several losses in
[TensorFlow Similarity](https://github.com/tensorflow/similarity). This loss
attempts to use all informative pairs in the batch, taking into account the
self-similarity, positive similarity, and negative similarity.

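For intuition, here is a simplified NumPy sketch of a multi-similarity style loss computed from a cosine similarity matrix. For each anchor, it adds `log(1 + sum(exp(-alpha * (S - lam)))) / alpha` over the anchor's positives and `log(1 + sum(exp(beta * (S - lam)))) / beta` over its negatives, then averages over anchors. This form is our reading of the linked paper. The sketch omits the paper's pair-mining step, and the `alpha`, `beta`, and `lam` values are illustrative, not the defaults of `tfsim.losses.MultiSimilarityLoss`.

```python
import numpy as np


def multi_similarity_loss(embeddings, labels, alpha=2.0, beta=40.0, lam=0.5):
    # Simplified multi-similarity loss WITHOUT the paper's pair-mining step.
    # alpha, beta and lam are illustrative values, not library defaults.
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = e @ e.T  # pairwise cosine similarities
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    per_anchor = []
    for i in range(len(labels)):
        pos = sim[i][same[i] & not_self[i]]  # similarities to positives
        neg = sim[i][~same[i]]  # similarities to negatives
        # Low positive similarity and high negative similarity both raise the loss.
        pos_term = np.log1p(np.sum(np.exp(-alpha * (pos - lam)))) / alpha
        neg_term = np.log1p(np.sum(np.exp(beta * (neg - lam)))) / beta
        per_anchor.append(pos_term + neg_term)
    return float(np.mean(per_anchor))


toy_labels = np.array([0, 0, 1, 1])
# Same-class points close together, classes far apart.
tight = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
# Same-class points far apart, different classes close together.
mixed = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]])
print(multi_similarity_loss(tight, toy_labels))  # smaller
print(multi_similarity_loss(mixed, toy_labels))  # larger
```

The `tight` embedding, which pulls same-class points together, yields a lower loss than `mixed`, where positives are far apart and negatives are close.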
"""

epochs = 3
learning_rate = 0.002
val_steps = 50

# Initialize the similarity loss.
loss = tfsim.losses.MultiSimilarityLoss()

# Compile and train the model.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    steps_per_execution=10,
)
history = model.fit(
    train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
)

"""
## Indexing

Now that we have trained our model, we can create an index of examples. Here we
batch-index the first 200 validation examples by passing `x_index` and `y_index`
to the index, and we store the images themselves in the `data` parameter. The
`x_index` values are embedded and then added to the index to make them
searchable. The `y_index` and `data` parameters are optional but let us
associate metadata with each embedded example.

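Conceptually, an index stores each embedding together with its label and optional data payload, and a lookup returns the nearest stored entries. The sketch below implements this with exhaustive (brute-force) cosine search in NumPy. `BruteForceIndex` is a hypothetical class for illustration only. TensorFlow Similarity's index may use a different search backend and a different API.

```python
import numpy as np


class BruteForceIndex:
    # Hypothetical exhaustive-search index: stores L2-normalized embeddings
    # alongside their labels and optional data payloads.

    def __init__(self):
        self.embeddings = []
        self.labels = []
        self.data = []

    def add(self, embeddings, labels, data=None):
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        self.embeddings.extend(embeddings / norms)  # one row per example
        self.labels.extend(labels)
        self.data.extend(data if data is not None else [None] * len(labels))

    def lookup(self, query, k):
        # Cosine distance from the query to every stored embedding.
        q = query / np.linalg.norm(query)
        dists = 1.0 - np.array(self.embeddings) @ q
        nearest = np.argsort(dists)[:k]
        return [(float(dists[j]), self.labels[j], self.data[j]) for j in nearest]


index = BruteForceIndex()
index.add(
    np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]),
    labels=[0, 1, 2],
    data=["a", "b", "c"],
)
# Nearest first: label 0 (payload "a"), then label 2 (payload "c").
for dist, label, payload in index.lookup(np.array([0.9, 0.1]), k=2):
    print(f"{dist:.3f}", label, payload)
```

Exhaustive search is simple and exact, but its cost grows linearly with the index size. Larger indexes usually rely on approximate nearest-neighbor search instead.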
"""

x_index, y_index = val_ds.get_slice(begin=0, size=200)
model.reset_index()
model.index(x_index, y_index, data=x_index)

"""
## Calibration

Once the index is built, we can calibrate a distance threshold using a matching
strategy and a calibration metric.

Here we search for the distance threshold that maximizes the F1 score, using the
nearest neighbor (K=1) as our classifier. All matches at or below the calibrated
threshold distance will be labeled as a Positive match between the query example
and the label associated with the match result, while all matches above the
threshold distance will be labeled as a Negative match.

Additionally, we pass in extra metrics to compute. The values reported for the
calibrated cutpoint are computed at the calibrated threshold.

Finally, `model.calibrate()` returns a `CalibrationResults` object containing:

* `"cutpoints"`: A Python dict mapping the cutpoint name to a dict containing the
`ClassificationMetric` values associated with a particular distance threshold,
e.g., `"optimal" : {"acc": 0.90, "f1": 0.92}`.
* `"thresholds"`: A Python dict mapping `ClassificationMetric` names to a list
containing the metric's value computed at each of the distance thresholds, e.g.,
`{"f1": [0.99, 0.80], "distance": [0.0, 1.0]}`.

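The calibration step can be sketched as a sweep over candidate thresholds. In the hypothetical `calibrate_threshold` helper below, each query has the distance to its nearest (K=1) neighbor and a flag that says whether that neighbor's label is correct. Matches at or below a threshold count as positive. The sketch then counts a true positive for each accepted correct match, a false positive for each accepted wrong match, and a false negative for each rejected correct match, and keeps the threshold with the best F1 score. These metric definitions are our assumption and may differ in detail from `model.calibrate()`.

```python
import numpy as np


def calibrate_threshold(distances, is_correct):
    # Hypothetical F1-based threshold calibration for a K=1 matcher.
    # distances[i]: distance from query i to its nearest indexed neighbor.
    # is_correct[i]: True if that neighbor's label matches query i's label.
    distances = np.asarray(distances, dtype=float)
    is_correct = np.asarray(is_correct, dtype=bool)
    best = {"f1": -1.0}
    for t in np.sort(distances):
        accepted = distances <= t  # at or below the threshold: positive match
        tp = np.sum(accepted & is_correct)
        fp = np.sum(accepted & ~is_correct)
        fn = np.sum(~accepted & is_correct)  # correct matches we rejected
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        if f1 > best["f1"]:
            best = {
                "distance": float(t),
                "precision": float(precision),
                "recall": float(recall),
                "f1": float(f1),
            }
    return best


# Toy data: close matches are correct, distant ones are wrong.
dists = [0.05, 0.10, 0.15, 0.30, 0.40, 0.60, 0.80]
correct = [True, True, True, True, False, False, False]
print(calibrate_threshold(dists, correct))
# {'distance': 0.3, 'precision': 1.0, 'recall': 1.0, 'f1': 1.0}
```

Raising the threshold increases recall and eventually lowers precision. Maximizing F1 picks the point that balances the two.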
"""

x_train, y_train = train_ds.get_slice(begin=0, size=1000)
calibration = model.calibrate(
    x_train,
    y_train,
    calibration_metric="f1",
    matcher="match_nearest",
    extra_metrics=["precision", "recall", "binary_accuracy"],
    verbose=1,
)

"""
## Visualization

It may be difficult to get a sense of the model quality from the metrics alone.
A complementary approach is to manually inspect a set of query results to get a
feel for the match quality.

Here we take 10 validation examples and plot them with their 5 nearest
neighbors and the distances to the query example. Looking at the results, we see
that while they are imperfect, they still represent meaningfully similar images,
and that the model is able to find similar images irrespective of their pose or
illumination.

We can also see that the model is very confident with certain images, resulting
in very small distances between the query and the neighbors. Conversely, we see
more mistakes in the class labels as the distances become larger. This is one of
the reasons why calibration is critical for matching applications.
"""

num_neighbors = 5
# Label 10 ("Unknown") is used later for queries with no match within the threshold.
labels = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
    "Unknown",
]
class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}

x_display, y_display = val_ds.get_slice(begin=200, size=10)
# Look up the nearest neighbors in the index.
nns = model.lookup(x_display, k=num_neighbors)

# Display each query with its neighbors, ordered by class label.
for idx in np.argsort(y_display):
    tfsim.visualization.viz_neigbors_imgs(
        x_display[idx],
        y_display[idx],
        nns[idx],
        class_mapping=class_mapping,
        fig_size=(16, 2),
    )

"""
## Metrics

We can also plot the extra metrics contained in the `CalibrationResults` to get
a sense of the matching performance as the distance threshold increases.

The following plots show the precision, recall, and F1 score. We can see that
the matching precision degrades as the distance increases, but that the
percentage of the queries that we accept as positive matches (recall) grows
faster up to the calibrated distance threshold.
"""

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
x = calibration.thresholds["distance"]

ax1.plot(x, calibration.thresholds["precision"], label="precision")
ax1.plot(x, calibration.thresholds["recall"], label="recall")
ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
ax1.legend()
ax1.set_title("Metric evolution as distance increases")
ax1.set_xlabel("Distance")
ax1.set_ylim((-0.05, 1.05))

ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax2.set_title("Precision recall curve")
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
ax2.set_ylim((-0.05, 1.05))
plt.show()

"""
We can also take 100 examples for each class and plot the confusion matrix
between each example's label and the label of its nearest match. We also add an
extra "Unknown" class (label 10) to represent the matches above the calibrated
distance threshold.

We can see that most of the errors are between the animal classes, with an
interesting number of confusions between Airplane and Bird. Additionally, we see
that only a few of the 100 examples for each class returned matches outside of
the calibrated distance threshold.

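The no-match handling can be sketched in NumPy as follows. The nearest neighbor's label is kept only when its distance is within the calibrated threshold, and an extra label is emitted otherwise. `match_with_threshold` and `confusion_matrix` are hypothetical helpers, not the implementations of `model.match()` or `tfsim.visualization.confusion_matrix()`.

```python
import numpy as np


def match_with_threshold(nn_labels, nn_distances, threshold, no_match_label):
    # Keep the nearest neighbor's label if it is within the threshold,
    # otherwise assign the extra "no match" label (10 in this tutorial).
    within = np.asarray(nn_distances) <= threshold
    return np.where(within, np.asarray(nn_labels), no_match_label)


def confusion_matrix(y_true, y_pred, num_labels):
    # Rows are true labels, columns are predicted labels.
    cm = np.zeros((num_labels, num_labels), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Toy example: real classes 0 and 1, plus no_match_label = 2.
y_true = [0, 0, 1, 1]
preds = match_with_threshold(
    nn_labels=[0, 1, 1, 1],
    nn_distances=[0.1, 0.2, 0.9, 0.3],
    threshold=0.5,
    no_match_label=2,
)
print(preds)  # [0 1 2 1]
print(confusion_matrix(y_true, preds, num_labels=3))
```

The last column counts queries whose nearest neighbor lay beyond the threshold. It plays the same role as the extra "Unknown" class in the tutorial's confusion matrix.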
"""

cutpoint = "optimal"

# This yields 100 examples for each class.
# We defined this when we created the val_ds sampler.
x_confusion, y_confusion = val_ds.get_slice(0, -1)

matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
cm = tfsim.visualization.confusion_matrix(
    matches,
    y_confusion,
    labels=labels,
    title="Confusion matrix for cutpoint:%s" % cutpoint,
    normalize=False,
)

"""
## No Match

We can plot the examples outside of the calibrated threshold to see which images
are not matching any indexed examples.

This may show which other examples need to be indexed, or surface anomalous
examples within a class.
"""

idx_no_match = np.where(np.array(matches) == 10)
no_match_queries = x_confusion[idx_no_match]
if len(no_match_queries):
    plt.imshow(no_match_queries[0])
else:
    print("All queries have a match below the distance threshold.")

"""
## Visualize clusters

One of the best ways to quickly get a sense of how well the model is doing and
to understand its shortcomings is to project the embedding into a 2D space.

This allows us to inspect clusters of images and understand which classes are
entangled.

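As a simple, deterministic stand-in for an interactive projector, the sketch below projects embeddings to 2D with principal component analysis (PCA), computed via NumPy's SVD. We do not assume which projection method `tfsim.visualization.projector` uses. PCA is swapped in here only because it needs no extra dependencies. Non-linear methods can separate clusters more clearly.

```python
import numpy as np


def pca_2d(embeddings):
    # Project embeddings onto their top-2 principal components via SVD.
    centered = embeddings - embeddings.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
# Two toy "classes" of 256-d embeddings scattered around different centers.
centers = rng.normal(size=(2, 256))
emb = np.vstack([c + 0.1 * rng.normal(size=(50, 256)) for c in centers])
points = pca_2d(emb)
print(points.shape)  # (100, 2)
```

Plotting `points` colored by class would show two well-separated clusters. With real embeddings, overlapping clusters point to entangled classes.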
"""

# Each class in val_ds was restricted to 100 examples, so 1000 examples cover
# all 10 classes.
num_examples_to_clusters = 1000
thumb_size = 96
plot_size = 800
vx, vy = val_ds.get_slice(0, num_examples_to_clusters)

# Uncomment to run the interactive projector.
# tfsim.visualization.projector(
#     model.predict(vx),
#     labels=vy,
#     images=vx,
#     class_mapping=class_mapping,
#     image_size=thumb_size,
#     plot_size=plot_size,
# )