"""
Title: Semi-supervision and domain adaptation with AdaMatch
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/06/19
Last modified: 2021/06/19
Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
Accelerator: GPU
"""

"""
## Introduction

In this example, we will implement the AdaMatch algorithm, proposed in
[AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation](https://arxiv.org/abs/2106.04732)
by Berthelot et al. It sets a new state-of-the-art in unsupervised domain adaptation (as of
June 2021). AdaMatch is particularly interesting because it
unifies semi-supervised learning (SSL) and unsupervised domain adaptation
(UDA) under one framework. It thereby provides a way to perform semi-supervised domain
adaptation (SSDA).

This example requires TensorFlow 2.5 or higher, as well as TensorFlow Models and KerasCV
(which provides the `RandAugment` layer used below). They can be installed using the
following command:
"""

"""shell
pip install -q tf-models-official==2.9.2 keras-cv
"""

"""
Before we proceed, let's review a few preliminary concepts underlying this example.
"""

"""
## Preliminaries

In **semi-supervised learning (SSL)**, we use a small amount of labeled data to
train models on a bigger unlabeled dataset. Popular semi-supervised learning methods
for computer vision include [FixMatch](https://arxiv.org/abs/2001.07685),
[MixMatch](https://arxiv.org/abs/1905.02249),
[Noisy Student Training](https://arxiv.org/abs/1911.04252), etc. You can refer to
[this example](https://keras.io/examples/vision/consistency_training/) to get an idea
of what a standard SSL workflow looks like.

In **unsupervised domain adaptation**, we have access to a labeled source dataset and
a target *unlabeled* dataset. The task is then to learn a model that generalizes well
to the target dataset. The source and the target datasets vary in terms of distribution.
The following figure provides an illustration of this idea. In the present example, we use the
[MNIST dataset](http://yann.lecun.com/exdb/mnist/) as the source dataset, while the target dataset is
[SVHN](http://ufldl.stanford.edu/housenumbers/), which consists of images of house
numbers. The two datasets differ in factors such as texture, viewpoint, and
appearance: their domains, or distributions, are different from one
another.

![](https://i.imgur.com/dJFSJuT.png)

Popular domain adaptation algorithms in deep learning include
[Deep CORAL](https://arxiv.org/abs/1612.01939),
[Moment Matching](https://arxiv.org/abs/1812.01754), etc.
"""

"""
## Setup
"""

import tensorflow as tf

tf.random.set_seed(42)

import numpy as np

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from keras_cv.layers import RandAugment

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

"""
## Prepare the data
"""

# MNIST
(
    (mnist_x_train, mnist_y_train),
    (mnist_x_test, mnist_y_test),
) = keras.datasets.mnist.load_data()

# Add a channel dimension
mnist_x_train = tf.expand_dims(mnist_x_train, -1)
mnist_x_test = tf.expand_dims(mnist_x_test, -1)

# Convert the labels to one-hot encoded vectors
mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()

# SVHN
svhn_train, svhn_test = tfds.load(
    "svhn_cropped", split=["train", "test"], as_supervised=True
)

"""
## Define constants and hyperparameters
"""

RESIZE_TO = 32

SOURCE_BATCH_SIZE = 64
TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE  # Reference: Section 3.2
EPOCHS = 10
STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH

AUTO = tf.data.AUTOTUNE
LEARNING_RATE = 0.03

WEIGHT_DECAY = 0.0005
INIT = "he_normal"
DEPTH = 28
WIDTH_MULT = 2

"""
## Data augmentation utilities

A standard element of SSL algorithms is to feed weakly and strongly augmented versions of
the same images to the learning model to make its predictions consistent. For strong
augmentation, [RandAugment](https://arxiv.org/abs/1909.13719) is a standard choice. For
weak augmentation, we will use horizontal flipping and random cropping.
"""

# Initialize the `RandAugment` object with 2 augmentation transforms
# per image and a magnitude of 0.5.
augmenter = RandAugment(value_range=(0, 255), augmentations_per_image=2, magnitude=0.5)


def weak_augment(image, source=True):
    if image.dtype != tf.float32:
        image = tf.cast(image, tf.float32)

    # MNIST images are grayscale, which is why we first convert them to
    # RGB images.
    if source:
        image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
        image = tf.tile(image, [1, 1, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
    return image


def strong_augment(image, source=True):
    if image.dtype != tf.float32:
        image = tf.cast(image, tf.float32)

    if source:
        image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
        image = tf.tile(image, [1, 1, 3])
    image = augmenter(image)
    return image


"""
## Data loading utilities
"""


def create_individual_ds(ds, aug_func, source=True):
    if source:
        batch_size = SOURCE_BATCH_SIZE
    else:
        # During training, 3x more target unlabeled samples are shown
        # to the model in AdaMatch (Section 3.2 of the paper).
        batch_size = TARGET_BATCH_SIZE
    ds = ds.shuffle(batch_size * 10, seed=42)

    if source:
        ds = ds.map(lambda x, y: (aug_func(x), y), num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda x, y: (aug_func(x, False), y), num_parallel_calls=AUTO)

    ds = ds.batch(batch_size).prefetch(AUTO)
    return ds


"""
The `_w` and `_s` suffixes denote the weakly and strongly augmented versions, respectively.
"""

source_ds = tf.data.Dataset.from_tensor_slices((mnist_x_train, mnist_y_train))
source_ds_w = create_individual_ds(source_ds, weak_augment)
source_ds_s = create_individual_ds(source_ds, strong_augment)
final_source_ds = tf.data.Dataset.zip((source_ds_w, source_ds_s))

target_ds_w = create_individual_ds(svhn_train, weak_augment, source=False)
target_ds_s = create_individual_ds(svhn_train, strong_augment, source=False)
final_target_ds = tf.data.Dataset.zip((target_ds_w, target_ds_s))


"""
Here's what a single image batch looks like:

![](https://i.imgur.com/aver8cG.png)
"""

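"""
If you'd like to generate such a grid yourself, the following is a minimal sketch (not
part of the original example) that prints the batch shapes and plots a few weakly and
strongly augmented source images; it assumes `matplotlib` is available.
"""

import matplotlib.pyplot as plt

# Grab one batch of weakly/strongly augmented source images.
(sample_w, _), (sample_s, _) = next(iter(final_source_ds))
print("Weakly augmented batch shape:", sample_w.shape)
print("Strongly augmented batch shape:", sample_s.shape)

plt.figure(figsize=(10, 4))
for i in range(5):
    # Top row: weak augmentation, bottom row: strong augmentation.
    plt.subplot(2, 5, i + 1)
    plt.imshow(tf.cast(tf.clip_by_value(sample_w[i], 0, 255), tf.uint8))
    plt.axis("off")
    plt.subplot(2, 5, i + 6)
    plt.imshow(tf.cast(tf.clip_by_value(sample_s[i], 0, 255), tf.uint8))
    plt.axis("off")
plt.show()
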
"""
## Loss computation utilities
"""


def compute_loss_source(source_labels, logits_source_w, logits_source_s):
    loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
    # First compute the losses between original source labels and
    # predictions made on the weakly and strongly augmented versions
    # of the same images.
    w_loss = loss_func(source_labels, logits_source_w)
    s_loss = loss_func(source_labels, logits_source_s)
    return w_loss + s_loss


def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
    loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
    target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
    # For calculating the loss for the target samples, we treat the pseudo labels
    # as the ground-truth. They are excluded from backpropagation (via
    # `tf.stop_gradient()` above), which is standard SSL practice.
    target_loss = loss_func(target_pseudo_labels_w, logits_target_s)

    # More on `mask` later.
    mask = tf.cast(mask, target_loss.dtype)
    target_loss *= mask
    return tf.reduce_mean(target_loss, 0)


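"""
As a quick shape sanity check (not part of the original example), we can call these
utilities on random tensors. The dummy tensors below are purely illustrative; note how the
boolean `mask` zeroes out the per-sample losses of low-confidence target predictions.
"""

# Hypothetical dummy inputs: a batch of 4 samples with 10 classes.
dummy_labels = tf.one_hot(tf.constant([0, 1, 2, 3]), 10)
dummy_logits_w = tf.random.normal((4, 10))
dummy_logits_s = tf.random.normal((4, 10))
dummy_mask = tf.constant([True, False, True, False])

print(
    "Source loss:",
    compute_loss_source(dummy_labels, dummy_logits_w, dummy_logits_s).numpy(),
)
print(
    "Target loss:",
    compute_loss_target(tf.nn.softmax(dummy_logits_w), dummy_logits_s, dummy_mask).numpy(),
)
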
"""
## Subclassed model for AdaMatch training

The figure below presents the overall workflow of AdaMatch (taken from the
[original paper](https://arxiv.org/abs/2106.04732)):

![](https://i.imgur.com/1QsEm2M.png)

Here's a brief step-by-step breakdown of the workflow:

1. We first retrieve the weakly and strongly augmented pairs of images from the source and
target datasets.
2. We prepare two concatenated copies:
    i. One where both pairs are concatenated.
    ii. One where only the source data image pair is concatenated.
3. We run two forward passes through the model:
    i. The first forward pass uses the concatenated copy obtained from **2.i**. In
    this forward pass, the [Batch Normalization](https://arxiv.org/abs/1502.03167) statistics
    are updated.
    ii. In the second forward pass, we only use the concatenated copy obtained from **2.ii**.
    Batch Normalization layers are run in inference mode.
4. The respective logits are computed for both forward passes.
5. The logits go through a series of transformations, introduced in the paper (which
we will discuss shortly).
6. We compute the loss and update the gradients of the underlying model.
"""


class AdaMatch(keras.Model):
    def __init__(self, model, total_steps, tau=0.9):
        super().__init__()
        self.model = model
        self.tau = tau  # Denotes the confidence threshold
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.total_steps = total_steps
        self.current_step = tf.Variable(0, dtype="int64")

    @property
    def metrics(self):
        return [self.loss_tracker]

    # This is a warmup schedule to update the weight of the
    # loss contributed by the target unlabeled samples. More
    # on this in the text.
    def compute_mu(self):
        pi = tf.constant(np.pi, dtype="float32")
        step = tf.cast(self.current_step, dtype="float32")
        return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2

    def train_step(self, data):
        ## Unpack and organize the data ##
        source_ds, target_ds = data
        (source_w, source_labels), (source_s, _) = source_ds
        (
            (target_w, _),
            (target_s, _),
        ) = target_ds  # Notice that we are NOT using any labels here.

        combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
        combined_source = tf.concat([source_w, source_s], 0)

        total_source = tf.shape(combined_source)[0]
        total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]

        with tf.GradientTape() as tape:
            ## Forward passes ##
            combined_logits = self.model(combined_images, training=True)
            z_d_prime_source = self.model(
                combined_source, training=False
            )  # No BatchNorm update.
            z_prime_source = combined_logits[:total_source]

            ## 1. Random logit interpolation for the source images ##
            lambd = tf.random.uniform((total_source, 10), 0, 1)
            final_source_logits = (lambd * z_prime_source) + (
                (1 - lambd) * z_d_prime_source
            )

            ## 2. Distribution alignment (only consider weakly augmented images) ##
            # Compute softmax for logits of the WEAKLY augmented SOURCE images.
            y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])

            # Extract logits for the WEAKLY augmented TARGET images and compute softmax.
            logits_target = combined_logits[total_source:]
            logits_target_w = logits_target[: tf.shape(target_w)[0]]
            y_hat_target_w = tf.nn.softmax(logits_target_w)

            # Align the target label distribution to that of the source.
            expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
                y_hat_target_w
            )
            y_tilde_target_w = tf.math.l2_normalize(
                y_hat_target_w * expectation_ratio, 1
            )

            ## 3. Relative confidence thresholding ##
            row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
            final_sum = tf.reduce_mean(row_wise_max, 0)
            c_tau = self.tau * final_sum
            mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau

            ## Compute losses (pay attention to the indexing) ##
            source_loss = compute_loss_source(
                source_labels,
                final_source_logits[: tf.shape(source_w)[0]],
                final_source_logits[tf.shape(source_w)[0] :],
            )
            target_loss = compute_loss_target(
                y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
            )

            t = self.compute_mu()  # Compute weight for the target loss
            total_loss = source_loss + (t * target_loss)
            self.current_step.assign_add(
                1
            )  # Update current training step for the scheduler

        gradients = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.loss_tracker.update_state(total_loss)
        return {"loss": self.loss_tracker.result()}


"""
The authors introduce three improvements in the paper:

* In AdaMatch, we perform two forward passes, and only one of them is responsible for
updating the Batch Normalization statistics. This is done to account for distribution
shifts in the target dataset. In the other forward pass, we only use the source samples,
and the Batch Normalization layers are run in inference mode. Logits for the source
samples (weakly and strongly augmented versions) from these two passes are slightly
different from one another because of how Batch Normalization layers are run. Final
logits for the source samples are computed by linearly interpolating between these two
different pairs of logits. This induces a form of consistency regularization. This step
is referred to as **random logit interpolation**.
* **Distribution alignment** is used to align the source and target label distributions.
This further helps the underlying model learn *domain-invariant representations*. In the
case of unsupervised domain adaptation, we don't have access to any labels of the target
dataset. This is why pseudo labels are generated from the underlying model.
* The underlying model generates pseudo-labels for the target samples. It's likely that
the model will make some faulty predictions. Those can propagate back as training
progresses and hurt the overall performance. To compensate for that, we keep only the
high-confidence predictions, filtered using a threshold (hence the use of `mask` inside
`compute_loss_target()`). In AdaMatch, this threshold is adjusted relative to the model's
confidence on the source data, which is why it is called **relative confidence
thresholding**.

For more details on these methods and to know how each of them contributes, please refer to
[the paper](https://arxiv.org/abs/2106.04732).

**About `compute_mu()`**:

Rather than using a fixed scalar quantity, AdaMatch uses a varying scalar. It
denotes the weight of the loss contributed by the target samples. Visually, the weight
scheduler looks like so:

![](https://i.imgur.com/dG7i9uH.png)

This scheduler increases the weight of the target domain loss from 0 to 1 for the first
half of the training. Then it keeps that weight at 1 for the second half of the training.
"""

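"""
As a quick sanity check of the schedule (not part of the original example), we can
re-implement the warmup formula with NumPy and print its value at a few points of
training. The helper name below is purely illustrative.
"""


def warmup_weight(step, total_steps):
    # Same formula as `AdaMatch.compute_mu()`, written with NumPy.
    return 0.5 - np.cos(min(np.pi, (2 * np.pi * step) / total_steps)) / 2


for step in [0, TOTAL_STEPS // 4, TOTAL_STEPS // 2, TOTAL_STEPS]:
    print(f"Step {step}: target loss weight = {warmup_weight(step, TOTAL_STEPS):.2f}")
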
"""
## Instantiate a Wide-ResNet-28-2

The authors use a [WideResNet-28-2](https://arxiv.org/abs/1605.07146) for the dataset
pairs we are using in this example. Most of the following code has been adapted from
[this script](https://github.com/asmith26/wide_resnets_keras/blob/master/main.py). Note
that the following model has a scaling layer inside it that scales the pixel values to
[0, 1].
"""


def wide_basic(x, n_input_plane, n_output_plane, stride):
    conv_params = [[3, 3, stride, "same"], [3, 3, (1, 1), "same"]]

    n_bottleneck_plane = n_output_plane

    # Residual block
    for i, v in enumerate(conv_params):
        if i == 0:
            if n_input_plane != n_output_plane:
                x = layers.BatchNormalization()(x)
                x = layers.Activation("relu")(x)
                convs = x
            else:
                convs = layers.BatchNormalization()(x)
                convs = layers.Activation("relu")(convs)
            convs = layers.Conv2D(
                n_bottleneck_plane,
                (v[0], v[1]),
                strides=v[2],
                padding=v[3],
                kernel_initializer=INIT,
                kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
                use_bias=False,
            )(convs)
        else:
            convs = layers.BatchNormalization()(convs)
            convs = layers.Activation("relu")(convs)
            convs = layers.Conv2D(
                n_bottleneck_plane,
                (v[0], v[1]),
                strides=v[2],
                padding=v[3],
                kernel_initializer=INIT,
                kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
                use_bias=False,
            )(convs)

    # Shortcut connection: identity function or 1x1 convolution
    # (depends on the difference between input & output shapes - this
    # corresponds to whether we are using the first block in each
    # group; see `block_series()`).
    if n_input_plane != n_output_plane:
        shortcut = layers.Conv2D(
            n_output_plane,
            (1, 1),
            strides=stride,
            padding="same",
            kernel_initializer=INIT,
            kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
            use_bias=False,
        )(x)
    else:
        shortcut = x

    return layers.Add()([convs, shortcut])


# Stacking residual units on the same stage
def block_series(x, n_input_plane, n_output_plane, count, stride):
    x = wide_basic(x, n_input_plane, n_output_plane, stride)
    for i in range(2, int(count + 1)):
        x = wide_basic(x, n_output_plane, n_output_plane, stride=1)
    return x


def get_network(image_size=32, num_classes=10):
    n = (DEPTH - 4) / 6
    n_stages = [16, 16 * WIDTH_MULT, 32 * WIDTH_MULT, 64 * WIDTH_MULT]

    inputs = keras.Input(shape=(image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    conv1 = layers.Conv2D(
        n_stages[0],
        (3, 3),
        strides=1,
        padding="same",
        kernel_initializer=INIT,
        kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
        use_bias=False,
    )(x)

    ## Add wide residual blocks ##

    conv2 = block_series(
        conv1,
        n_input_plane=n_stages[0],
        n_output_plane=n_stages[1],
        count=n,
        stride=(1, 1),
    )  # Stage 1

    conv3 = block_series(
        conv2,
        n_input_plane=n_stages[1],
        n_output_plane=n_stages[2],
        count=n,
        stride=(2, 2),
    )  # Stage 2

    conv4 = block_series(
        conv3,
        n_input_plane=n_stages[2],
        n_output_plane=n_stages[3],
        count=n,
        stride=(2, 2),
    )  # Stage 3

    batch_norm = layers.BatchNormalization()(conv4)
    relu = layers.Activation("relu")(batch_norm)

    # Classifier
    trunk_outputs = layers.GlobalAveragePooling2D()(relu)
    outputs = layers.Dense(
        num_classes, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
    )(trunk_outputs)

    return keras.Model(inputs, outputs)


"""
We can now instantiate a Wide ResNet model like so. Note that the purpose of using a
Wide ResNet here is to keep the implementation as close to the original one
as possible.
"""

wrn_model = get_network()
print(f"Model has {wrn_model.count_params() / 1e6:.2f} million parameters.")


"""
## Instantiate AdaMatch model and compile it
"""

reduce_lr = keras.optimizers.schedules.CosineDecay(LEARNING_RATE, TOTAL_STEPS, 0.25)
optimizer = keras.optimizers.Adam(reduce_lr)

adamatch_trainer = AdaMatch(model=wrn_model, total_steps=TOTAL_STEPS)
adamatch_trainer.compile(optimizer=optimizer)


"""
## Model training
"""

total_ds = tf.data.Dataset.zip((final_source_ds, final_target_ds))
adamatch_trainer.fit(total_ds, epochs=EPOCHS)


"""
## Evaluation on the target and source test sets
"""

# Compile the AdaMatch model to yield accuracy.
adamatch_trained_model = adamatch_trainer.model
adamatch_trained_model.compile(metrics=[keras.metrics.SparseCategoricalAccuracy()])

# Score on the target test set.
svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
_, accuracy = adamatch_trained_model.evaluate(svhn_test)
print(f"Accuracy on target test set: {accuracy * 100:.2f}%")


"""
With more training, this score improves. When this same network is trained with
a standard classification objective, it yields an accuracy of **7.20%**, which is
significantly lower than what we got with AdaMatch. You can check out
[this notebook](https://colab.research.google.com/github/sayakpaul/AdaMatch-TF/blob/main/Vanilla_WideResNet.ipynb)
to learn more about the hyperparameters and other experimental details.
"""


# Utility function for preprocessing the source test set.
def prepare_test_ds_source(image, label):
    image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
    image = tf.tile(image, [1, 1, 3])
    return image, label


source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
source_test_ds = (
    source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
    .batch(TARGET_BATCH_SIZE)
    .prefetch(AUTO)
)

# Evaluation on the source test set.
_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
print(f"Accuracy on source test set: {accuracy * 100:.2f}%")


"""
You can reproduce the results by using these
[model weights](https://github.com/sayakpaul/AdaMatch-TF/releases/tag/v1.0.0).
"""

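"""
For instance, after downloading the released weights you could load them into the same
architecture along the following lines. This is a minimal sketch, left commented out
because the local file name below is hypothetical and depends on the asset you download
from the release.
"""

# pretrained_model = get_network()
# pretrained_model.load_weights("adamatch_wrn_weights.h5")  # Hypothetical local path.
# pretrained_model.compile(metrics=[keras.metrics.SparseCategoricalAccuracy()])
# _, accuracy = pretrained_model.evaluate(svhn_test)
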

"""
**Example available on HuggingFace**

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-AdaMatch%20Domain%20Adaption-black.svg)](https://huggingface.co/keras-io/adamatch-domain-adaption) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-AdaMatch%20Domain%20Adaption-black.svg)](https://huggingface.co/spaces/keras-io/adamatch-domain-adaption) |
"""