"""
Title: Masked image modeling with Autoencoders
Author: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/12/20
Last modified: 2021/12/21
Description: Implementing Masked Autoencoders for self-supervised pretraining.
Accelerator: GPU
"""

"""
## Introduction

In deep learning, models with growing **capacity** and **capability** can easily overfit
on large datasets (like ImageNet-1K). In the field of natural language processing, the
appetite for data has been **successfully addressed** by self-supervised pretraining.

In the academic paper
[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377),
He et al. propose a simple yet effective method to pretrain large
vision models (here [ViT Huge](https://arxiv.org/abs/2010.11929)). Inspired by
the pretraining algorithm of BERT ([Devlin et al.](https://arxiv.org/abs/1810.04805)),
they mask patches of an image and, through an autoencoder, predict the masked patches.
In the spirit of "masked language modeling", this pretraining task could be referred
to as "masked image modeling".

In this example, we implement
[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
with the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. After
pretraining a scaled-down version of ViT, we also implement the linear evaluation
pipeline on CIFAR-10.

This implementation covers (MAE refers to Masked Autoencoder):

- The masking algorithm
- MAE encoder
- MAE decoder
- Evaluation with linear probing

As a reference, we reuse some of the code presented in
[this example](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers

import matplotlib.pyplot as plt
import numpy as np
import random

# Setting seeds for reproducibility.
SEED = 42
keras.utils.set_random_seed(SEED)
"""
65
## Hyperparameters for pretraining
66
67
Please feel free to change the hyperparameters and check your results. The best way to
68
get an intuition about the architecture is to experiment with it. Our hyperparameters are
69
heavily inspired by the design guidelines laid out by the authors in
70
[the original paper](https://arxiv.org/abs/2111.06377).
71
"""
72

# DATA
BUFFER_SIZE = 1024
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-4

# PRETRAINING
EPOCHS = 100

# AUGMENTATION
IMAGE_SIZE = 48  # We will resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
MASK_PROPORTION = 0.75  # We have found 75% masking to give us the best results.

# ENCODER and DECODER
LAYER_NORM_EPS = 1e-6
ENC_PROJECTION_DIM = 128
DEC_PROJECTION_DIM = 64
ENC_NUM_HEADS = 4
ENC_LAYERS = 6
DEC_NUM_HEADS = 4
DEC_LAYERS = (
    2  # The decoder is lightweight but should be reasonably deep for reconstruction.
)
ENC_TRANSFORMER_UNITS = [
    ENC_PROJECTION_DIM * 2,
    ENC_PROJECTION_DIM,
]  # Size of the transformer layers.
DEC_TRANSFORMER_UNITS = [
    DEC_PROJECTION_DIM * 2,
    DEC_PROJECTION_DIM,
]
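
"""
With these settings, a 48x48 input is split into `(48 // 6) ** 2 = 64` patches, of which
`int(0.75 * 64) = 48` are masked and only 16 are visible to the encoder. A quick check of
this patch arithmetic:
"""

num_masked = int(MASK_PROPORTION * NUM_PATCHES)
print(f"Total patches: {NUM_PATCHES}")
print(f"Masked patches: {num_masked}")
print(f"Visible patches: {NUM_PATCHES - num_masked}")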

"""
## Load and prepare the CIFAR-10 dataset
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf.data.Dataset.from_tensor_slices(x_test)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
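
"""
Note that these datasets carry raw `uint8` images; rescaling to the `[0, 1]` range
happens later, inside the augmentation models. A quick look at the element spec:
"""

print(train_ds.element_spec)  # TensorSpec(shape=(None, 32, 32, 3), dtype=tf.uint8)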

"""
## Data augmentation

In previous self-supervised pretraining methodologies
(like [SimCLR](https://arxiv.org/abs/2002.05709)), we have noticed that the data
augmentation pipeline plays an important role. On the other hand, the authors of this
paper point out that Masked Autoencoders **do not** rely on augmentations. They propose a
simple augmentation pipeline of:

- Resizing
- Random cropping (fixed-sized or random-sized)
- Random horizontal flipping
"""


def get_train_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
            layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
            layers.RandomFlip("horizontal"),
        ],
        name="train_data_augmentation",
    )
    return model


def get_test_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        ],
        name="test_data_augmentation",
    )
    return model
"""
175
## A layer for extracting patches from images
176
177
This layer takes images as input and divides them into patches. The layer also includes
178
two utility method:
179
180
- `show_patched_image` -- Takes a batch of images and its corresponding patches to plot a
181
random pair of image and patches.
182
- `reconstruct_from_patch` -- Takes a single instance of patches and stitches them
183
together into the original image.
184
"""
185
186
187
class Patches(layers.Layer):
188
def __init__(self, patch_size=PATCH_SIZE, **kwargs):
189
super().__init__(**kwargs)
190
self.patch_size = patch_size
191
192
# Assuming the image has three channels each patch would be
193
# of size (patch_size, patch_size, 3).
194
self.resize = layers.Reshape((-1, patch_size * patch_size * 3))
195
196
def call(self, images):
197
# Create patches from the input images
198
patches = tf.image.extract_patches(
199
images=images,
200
sizes=[1, self.patch_size, self.patch_size, 1],
201
strides=[1, self.patch_size, self.patch_size, 1],
202
rates=[1, 1, 1, 1],
203
padding="VALID",
204
)
205
206
# Reshape the patches to (batch, num_patches, patch_area) and return it.
207
patches = self.resize(patches)
208
return patches
209
210
def show_patched_image(self, images, patches):
211
# This is a utility function which accepts a batch of images and its
212
# corresponding patches and help visualize one image and its patches
213
# side by side.
214
idx = np.random.choice(patches.shape[0])
215
print(f"Index selected: {idx}.")
216
217
plt.figure(figsize=(4, 4))
218
plt.imshow(keras.utils.array_to_img(images[idx]))
219
plt.axis("off")
220
plt.show()
221
222
n = int(np.sqrt(patches.shape[1]))
223
plt.figure(figsize=(4, 4))
224
for i, patch in enumerate(patches[idx]):
225
ax = plt.subplot(n, n, i + 1)
226
patch_img = tf.reshape(patch, (self.patch_size, self.patch_size, 3))
227
plt.imshow(keras.utils.img_to_array(patch_img))
228
plt.axis("off")
229
plt.show()
230
231
# Return the index chosen to validate it outside the method.
232
return idx
233
234
# taken from https://stackoverflow.com/a/58082878/10319735
235
def reconstruct_from_patch(self, patch):
236
# This utility function takes patches from a *single* image and
237
# reconstructs it back into the image. This is useful for the train
238
# monitor callback.
239
num_patches = patch.shape[0]
240
n = int(np.sqrt(num_patches))
241
patch = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))
242
rows = tf.split(patch, n, axis=0)
243
rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
244
reconstructed = tf.concat(rows, axis=0)
245
return reconstructed
246
247

"""
Let's visualize the image patches.
"""

# Get a batch of images.
image_batch = next(iter(train_ds))

# Augment the images.
augmentation_model = get_train_augmentation_model()
augmented_images = augmentation_model(image_batch)

# Define the patch layer.
patch_layer = Patches()

# Get the patches from the batched images.
patches = patch_layer(images=augmented_images)

# Now pass the images and the corresponding patches
# to the `show_patched_image` method.
random_index = patch_layer.show_patched_image(images=augmented_images, patches=patches)

# Choose the same image and try reconstructing the patches
# back into the original image.
image = patch_layer.reconstruct_from_patch(patches[random_index])
plt.imshow(image)
plt.axis("off")
plt.show()

"""
## Patch encoding with masking

Quoting the paper:

> Following ViT, we divide an image into regular non-overlapping patches. Then we sample
a subset of patches and mask (i.e., remove) the remaining ones. Our sampling strategy is
straightforward: we sample random patches without replacement, following a uniform
distribution. We simply refer to this as “random sampling”.

This layer handles masking as well as encoding the patches.

The utility methods of the layer are:

- `get_random_indices` -- Provides the mask and unmask indices.
- `generate_masked_image` -- Takes patches and unmask indices, and returns a randomly
  masked image. This is an essential utility method for our training monitor callback
  (defined later).
"""


class PatchEncoder(layers.Layer):
    def __init__(
        self,
        patch_size=PATCH_SIZE,
        projection_dim=ENC_PROJECTION_DIM,
        mask_proportion=MASK_PROPORTION,
        downstream=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.mask_proportion = mask_proportion
        self.downstream = downstream

        # This is a trainable mask token initialized randomly from a normal
        # distribution.
        self.mask_token = tf.Variable(
            tf.random.normal([1, patch_size * patch_size * 3]), trainable=True
        )

    def build(self, input_shape):
        (_, self.num_patches, self.patch_area) = input_shape

        # Create the projection layer for the patches.
        self.projection = layers.Dense(units=self.projection_dim)

        # Create the positional embedding layer.
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=self.projection_dim
        )

        # Number of patches that will be masked.
        self.num_mask = int(self.mask_proportion * self.num_patches)

    def call(self, patches):
        # Get the positional embeddings.
        batch_size = tf.shape(patches)[0]
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
        pos_embeddings = tf.tile(
            pos_embeddings, [batch_size, 1, 1]
        )  # (B, num_patches, projection_dim)

        # Embed the patches.
        patch_embeddings = (
            self.projection(patches) + pos_embeddings
        )  # (B, num_patches, projection_dim)

        if self.downstream:
            return patch_embeddings
        else:
            mask_indices, unmask_indices = self.get_random_indices(batch_size)
            # The encoder input is the unmasked patch embeddings. Here we gather
            # all the patches that should be unmasked.
            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)

            # Get the unmasked and masked position embeddings. We will need them
            # for the decoder.
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )  # (B, mask_numbers, projection_dim)

            # Repeat the mask token `num_mask` times: one mask token stands in for
            # every masked patch of the image.
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )

            # Get the masked embeddings for the tokens.
            masked_embeddings = self.projection(mask_tokens) + masked_positions
            return (
                unmasked_embeddings,  # Input to the encoder.
                masked_embeddings,  # First part of input to the decoder.
                unmasked_positions,  # Added to the encoder outputs.
                mask_indices,  # The indices that were masked.
                unmask_indices,  # The indices that were unmasked.
            )

    def get_random_indices(self, batch_size):
        # Create random indices from a uniform distribution and then split
        # them into mask and unmask indices.
        rand_indices = tf.argsort(
            tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
        )
        mask_indices = rand_indices[:, : self.num_mask]
        unmask_indices = rand_indices[:, self.num_mask :]
        return mask_indices, unmask_indices

    def generate_masked_image(self, patches, unmask_indices):
        # Choose a random patch and its corresponding unmask indices.
        idx = np.random.choice(patches.shape[0])
        patch = patches[idx]
        unmask_index = unmask_indices[idx]

        # Build a numpy array of the same shape as the patch.
        new_patch = np.zeros_like(patch)

        # Iterate over new_patch and plug in the unmasked patches.
        for i in range(unmask_index.shape[0]):
            new_patch[unmask_index[i]] = patch[unmask_index[i]]
        return new_patch, idx
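
"""
Both the encoder inputs and the reconstruction targets are selected with
`tf.gather(..., batch_dims=1)`, which picks a different set of patch indices for every
image in the batch. A tiny illustration of that behaviour on a toy tensor:
"""

toy_embeddings = tf.reshape(tf.range(2 * 4 * 1), (2, 4, 1))  # (batch, patches, dim)
toy_indices = tf.constant([[0, 2], [1, 3]])  # per-image patch indices
print(tf.gather(toy_embeddings, toy_indices, axis=1, batch_dims=1)[..., 0].numpy())
# [[0 2]
#  [5 7]]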
"""
409
Let's see the masking process in action on a sample image.
410
"""
411
412
# Create the patch encoder layer.
413
patch_encoder = PatchEncoder()
414
415
# Get the embeddings and positions.
416
(
417
unmasked_embeddings,
418
masked_embeddings,
419
unmasked_positions,
420
mask_indices,
421
unmask_indices,
422
) = patch_encoder(patches=patches)
423
424
425
# Show a maksed patch image.
426
new_patch, random_index = patch_encoder.generate_masked_image(patches, unmask_indices)
427
428
plt.figure(figsize=(10, 10))
429
plt.subplot(1, 2, 1)
430
img = patch_layer.reconstruct_from_patch(new_patch)
431
plt.imshow(keras.utils.array_to_img(img))
432
plt.axis("off")
433
plt.title("Masked")
434
plt.subplot(1, 2, 2)
435
img = augmented_images[random_index]
436
plt.imshow(keras.utils.array_to_img(img))
437
plt.axis("off")
438
plt.title("Original")
439
plt.show()
440
441
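
"""
With the default hyperparameters, each 48x48 image yields 64 patches and only 16 of them
are kept visible, so the tensors returned above have the following shapes (assuming the
defaults):
"""

print(f"Unmasked embeddings: {unmasked_embeddings.shape}")  # (batch, 16, 128)
print(f"Masked embeddings: {masked_embeddings.shape}")  # (batch, 48, 128)
print(f"Mask indices: {mask_indices.shape}")  # (batch, 48)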
"""
442
## MLP
443
444
This serves as the fully connected feed forward network of the transformer architecture.
445
"""
446
447
448
def mlp(x, dropout_rate, hidden_units):
449
for units in hidden_units:
450
x = layers.Dense(units, activation=tf.nn.gelu)(x)
451
x = layers.Dropout(dropout_rate)(x)
452
return x
453
454
455
"""
456
## MAE encoder
457
458
The MAE encoder is ViT. The only point to note here is that the encoder outputs a layer
459
normalized output.
460
"""
461
462
463
def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
464
inputs = layers.Input((None, ENC_PROJECTION_DIM))
465
x = inputs
466
467
for _ in range(num_layers):
468
# Layer normalization 1.
469
x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
470
471
# Create a multi-head attention layer.
472
attention_output = layers.MultiHeadAttention(
473
num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
474
)(x1, x1)
475
476
# Skip connection 1.
477
x2 = layers.Add()([attention_output, x])
478
479
# Layer normalization 2.
480
x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
481
482
# MLP.
483
x3 = mlp(x3, hidden_units=ENC_TRANSFORMER_UNITS, dropout_rate=0.1)
484
485
# Skip connection 2.
486
x = layers.Add()([x3, x2])
487
488
outputs = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
489
return keras.Model(inputs, outputs, name="mae_encoder")
490
491
492
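
"""
The encoder operates on a variable number of tokens: during pretraining it only sees the
16 visible patch embeddings, while during linear probing (later) it sees all 64. As a
quick, illustrative shape check on a dummy batch (assuming the default hyperparameters):
"""

dummy_tokens = tf.zeros((1, 16, ENC_PROJECTION_DIM))
print(create_encoder()(dummy_tokens).shape)  # (1, 16, 128)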
"""
493
## MAE decoder
494
495
The authors point out that they use an **asymmetric** autoencoder model. They use a
496
lightweight decoder that takes "<10% computation per token vs. the encoder". We are not
497
specific with the "<10% computation" in our implementation but have used a smaller
498
decoder (both in terms of depth and projection dimensions).
499
"""
500
501
502
def create_decoder(
503
num_layers=DEC_LAYERS, num_heads=DEC_NUM_HEADS, image_size=IMAGE_SIZE
504
):
505
inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
506
x = layers.Dense(DEC_PROJECTION_DIM)(inputs)
507
508
for _ in range(num_layers):
509
# Layer normalization 1.
510
x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
511
512
# Create a multi-head attention layer.
513
attention_output = layers.MultiHeadAttention(
514
num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
515
)(x1, x1)
516
517
# Skip connection 1.
518
x2 = layers.Add()([attention_output, x])
519
520
# Layer normalization 2.
521
x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
522
523
# MLP.
524
x3 = mlp(x3, hidden_units=DEC_TRANSFORMER_UNITS, dropout_rate=0.1)
525
526
# Skip connection 2.
527
x = layers.Add()([x3, x2])
528
529
x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
530
x = layers.Flatten()(x)
531
pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
532
outputs = layers.Reshape((image_size, image_size, 3))(pre_final)
533
534
return keras.Model(inputs, outputs, name="mae_decoder")
535
536
537
"""
538
## MAE trainer
539
540
This is the trainer module. We wrap the encoder and decoder inside of a `tf.keras.Model`
541
subclass. This allows us to customize what happens in the `model.fit()` loop.
542
"""
543
544
545
class MaskedAutoencoder(keras.Model):
546
def __init__(
547
self,
548
train_augmentation_model,
549
test_augmentation_model,
550
patch_layer,
551
patch_encoder,
552
encoder,
553
decoder,
554
**kwargs,
555
):
556
super().__init__(**kwargs)
557
self.train_augmentation_model = train_augmentation_model
558
self.test_augmentation_model = test_augmentation_model
559
self.patch_layer = patch_layer
560
self.patch_encoder = patch_encoder
561
self.encoder = encoder
562
self.decoder = decoder
563
564
def calculate_loss(self, images, test=False):
565
# Augment the input images.
566
if test:
567
augmented_images = self.test_augmentation_model(images)
568
else:
569
augmented_images = self.train_augmentation_model(images)
570
571
# Patch the augmented images.
572
patches = self.patch_layer(augmented_images)
573
574
# Encode the patches.
575
(
576
unmasked_embeddings,
577
masked_embeddings,
578
unmasked_positions,
579
mask_indices,
580
unmask_indices,
581
) = self.patch_encoder(patches)
582
583
# Pass the unmaksed patche to the encoder.
584
encoder_outputs = self.encoder(unmasked_embeddings)
585
586
# Create the decoder inputs.
587
encoder_outputs = encoder_outputs + unmasked_positions
588
decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
589
590
# Decode the inputs.
591
decoder_outputs = self.decoder(decoder_inputs)
592
decoder_patches = self.patch_layer(decoder_outputs)
593
594
loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
595
loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)
596
597
# Compute the total loss.
598
total_loss = self.compute_loss(y=loss_patch, y_pred=loss_output)
599
600
return total_loss, loss_patch, loss_output
601
602
def train_step(self, images):
603
with tf.GradientTape() as tape:
604
total_loss, loss_patch, loss_output = self.calculate_loss(images)
605
606
# Apply gradients.
607
train_vars = [
608
self.train_augmentation_model.trainable_variables,
609
self.patch_layer.trainable_variables,
610
self.patch_encoder.trainable_variables,
611
self.encoder.trainable_variables,
612
self.decoder.trainable_variables,
613
]
614
grads = tape.gradient(total_loss, train_vars)
615
tv_list = []
616
for grad, var in zip(grads, train_vars):
617
for g, v in zip(grad, var):
618
tv_list.append((g, v))
619
self.optimizer.apply_gradients(tv_list)
620
621
# Report progress.
622
results = {}
623
for metric in self.metrics:
624
metric.update_state(loss_patch, loss_output)
625
results[metric.name] = metric.result()
626
return results
627
628
def test_step(self, images):
629
total_loss, loss_patch, loss_output = self.calculate_loss(images, test=True)
630
631
# Update the trackers.
632
results = {}
633
for metric in self.metrics:
634
metric.update_state(loss_patch, loss_output)
635
results[metric.name] = metric.result()
636
return results
637
638
639
"""
640
## Model initialization
641
"""
642
643
train_augmentation_model = get_train_augmentation_model()
644
test_augmentation_model = get_test_augmentation_model()
645
patch_layer = Patches()
646
patch_encoder = PatchEncoder()
647
encoder = create_encoder()
648
decoder = create_decoder()
649
650
mae_model = MaskedAutoencoder(
651
train_augmentation_model=train_augmentation_model,
652
test_augmentation_model=test_augmentation_model,
653
patch_layer=patch_layer,
654
patch_encoder=patch_encoder,
655
encoder=encoder,
656
decoder=decoder,
657
)
658
659
"""
660
## Training callbacks
661
"""
662
663
"""
664
### Visualization callback
665
"""
666
667
# Taking a batch of test inputs to measure model's progress.
668
test_images = next(iter(test_ds))
669
670
671
class TrainMonitor(keras.callbacks.Callback):
672
def __init__(self, epoch_interval=None):
673
self.epoch_interval = epoch_interval
674
675
def on_epoch_end(self, epoch, logs=None):
676
if self.epoch_interval and epoch % self.epoch_interval == 0:
677
test_augmented_images = self.model.test_augmentation_model(test_images)
678
test_patches = self.model.patch_layer(test_augmented_images)
679
(
680
test_unmasked_embeddings,
681
test_masked_embeddings,
682
test_unmasked_positions,
683
test_mask_indices,
684
test_unmask_indices,
685
) = self.model.patch_encoder(test_patches)
686
test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
687
test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
688
test_decoder_inputs = tf.concat(
689
[test_encoder_outputs, test_masked_embeddings], axis=1
690
)
691
test_decoder_outputs = self.model.decoder(test_decoder_inputs)
692
693
# Show a maksed patch image.
694
test_masked_patch, idx = self.model.patch_encoder.generate_masked_image(
695
test_patches, test_unmask_indices
696
)
697
print(f"\nIdx chosen: {idx}")
698
original_image = test_augmented_images[idx]
699
masked_image = self.model.patch_layer.reconstruct_from_patch(
700
test_masked_patch
701
)
702
reconstructed_image = test_decoder_outputs[idx]
703
704
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
705
ax[0].imshow(original_image)
706
ax[0].set_title(f"Original: {epoch:03d}")
707
708
ax[1].imshow(masked_image)
709
ax[1].set_title(f"Masked: {epoch:03d}")
710
711
ax[2].imshow(reconstructed_image)
712
ax[2].set_title(f"Resonstructed: {epoch:03d}")
713
714
plt.show()
715
plt.close()
716
717
718
"""
719
### Learning rate scheduler
720
"""
721
722
# Some code is taken from:
723
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
724
725
726
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
727
def __init__(
728
self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
729
):
730
super().__init__()
731
732
self.learning_rate_base = learning_rate_base
733
self.total_steps = total_steps
734
self.warmup_learning_rate = warmup_learning_rate
735
self.warmup_steps = warmup_steps
736
self.pi = tf.constant(np.pi)
737
738
def __call__(self, step):
739
if self.total_steps < self.warmup_steps:
740
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
741
742
cos_annealed_lr = tf.cos(
743
self.pi
744
* (tf.cast(step, tf.float32) - self.warmup_steps)
745
/ float(self.total_steps - self.warmup_steps)
746
)
747
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
748
749
if self.warmup_steps > 0:
750
if self.learning_rate_base < self.warmup_learning_rate:
751
raise ValueError(
752
"Learning_rate_base must be larger or equal to "
753
"warmup_learning_rate."
754
)
755
slope = (
756
self.learning_rate_base - self.warmup_learning_rate
757
) / self.warmup_steps
758
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
759
learning_rate = tf.where(
760
step < self.warmup_steps, warmup_rate, learning_rate
761
)
762
return tf.where(
763
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
764
)
765
766
767
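
"""
Past the warmup phase, the schedule above follows a half-cosine decay,
`lr(step) = 0.5 * learning_rate_base * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps)))`,
while during warmup the learning rate rises linearly from `warmup_learning_rate` to
`learning_rate_base`, and it is clamped to 0 once `step` exceeds `total_steps`.
"""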

total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

lrs = [scheduled_lrs(step) for step in range(total_steps)]
plt.plot(lrs)
plt.xlabel("Step", fontsize=14)
plt.ylabel("LR", fontsize=14)
plt.show()

# Assemble the callbacks.
train_callbacks = [TrainMonitor(epoch_interval=5)]
"""
787
## Model compilation and training
788
"""
789
790
optimizer = keras.optimizers.AdamW(
791
learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
792
)
793
794
# Compile and pretrain the model.
795
mae_model.compile(
796
optimizer=optimizer, loss=keras.losses.MeanSquaredError(), metrics=["mae"]
797
)
798
history = mae_model.fit(
799
train_ds,
800
epochs=EPOCHS,
801
validation_data=val_ds,
802
callbacks=train_callbacks,
803
)
804
805
# Measure its performance.
806
loss, mae = mae_model.evaluate(test_ds)
807
print(f"Loss: {loss:.2f}")
808
print(f"MAE: {mae:.2f}")
809
810
"""
811
## Evaluation with linear probing
812
"""
813
814
"""
815
### Extract the encoder model along with other layers
816
"""
817
818
# Extract the augmentation layers.
819
train_augmentation_model = mae_model.train_augmentation_model
820
test_augmentation_model = mae_model.test_augmentation_model
821
822
# Extract the patchers.
823
patch_layer = mae_model.patch_layer
824
patch_encoder = mae_model.patch_encoder
825
patch_encoder.downstream = True # Swtich the downstream flag to True.
826
827
# Extract the encoder.
828
encoder = mae_model.encoder
829
830
# Pack as a model.
831
downstream_model = keras.Sequential(
832
[
833
layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
834
patch_layer,
835
patch_encoder,
836
encoder,
837
layers.BatchNormalization(), # Refer to A.1 (Linear probing).
838
layers.GlobalAveragePooling1D(),
839
layers.Dense(NUM_CLASSES, activation="softmax"),
840
],
841
name="linear_probe_model",
842
)
843
844
# Only the final classification layer of the `downstream_model` should be trainable.
845
for layer in downstream_model.layers[:-1]:
846
layer.trainable = False
847
848
downstream_model.summary()
849
850
"""
851
We are using average pooling to extract learned representations from the MAE encoder.
852
Another approach would be to use a learnable dummy token inside the encoder during
853
pretraining (resembling the [CLS] token). Then we can extract representations from that
854
token during the downstream tasks.
855
"""
856
857
"""
858
### Prepare datasets for linear probing
859
"""
860
861
862
def prepare_data(images, labels, is_train=True):
863
if is_train:
864
augmentation_model = train_augmentation_model
865
else:
866
augmentation_model = test_augmentation_model
867
868
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
869
if is_train:
870
dataset = dataset.shuffle(BUFFER_SIZE)
871
872
dataset = dataset.batch(BATCH_SIZE).map(
873
lambda x, y: (augmentation_model(x), y), num_parallel_calls=AUTO
874
)
875
return dataset.prefetch(AUTO)
876
877
878
train_ds = prepare_data(x_train, y_train)
879
val_ds = prepare_data(x_train, y_train, is_train=False)
880
test_ds = prepare_data(x_test, y_test, is_train=False)
881

"""
### Perform linear probing
"""

linear_probe_epochs = 50
linear_probe_lr = 0.1
warmup_epoch_percentage = 0.1
steps = int((len(x_train) // BATCH_SIZE) * linear_probe_epochs)

warmup_steps = int(steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=linear_probe_lr,
    total_steps=steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

optimizer = keras.optimizers.SGD(learning_rate=scheduled_lrs, momentum=0.9)
downstream_model.compile(
    optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
downstream_model.fit(train_ds, validation_data=val_ds, epochs=linear_probe_epochs)

loss, accuracy = downstream_model.evaluate(test_ds)
accuracy = round(accuracy * 100, 2)
print(f"Accuracy on the test set: {accuracy}%.")
"""
910
We believe that with a more sophisticated hyperparameter tuning process and a longer
911
pretraining it is possible to improve this performance further. For comparison, we took
912
the encoder architecture and
913
[trained it from scratch](https://github.com/ariG23498/mae-scalable-vision-learners/blob/master/regular-classification.ipynb)
914
in a fully supervised manner. This gave us ~76% test top-1 accuracy. The authors of
915
MAE demonstrates strong performance on the ImageNet-1k dataset as well as
916
other downstream tasks like object detection and semantic segmentation.
917
"""
918
919
"""
920
## Final notes
921
922
We refer the interested readers to other examples on self-supervised learning present on
923
keras.io:
924
925
* [SimCLR](https://keras.io/examples/vision/semisupervised_simclr/)
926
* [NNCLR](https://keras.io/examples/vision/nnclr)
927
* [SimSiam](https://keras.io/examples/vision/simsiam)
928
929
This idea of using BERT flavored pretraining in computer vision was also explored in
930
[Selfie](https://arxiv.org/abs/1906.02940), but it could not demonstrate strong results.
931
Another concurrent work that explores the idea of masked image modeling is
932
[SimMIM](https://arxiv.org/abs/2111.09886). Finally, as a fun fact, we, the authors of
933
this example also explored the idea of ["reconstruction as a pretext task"](https://i.ibb.co/k5CpwDX/image.png)
934
in 2020 but we could not prevent the network from representation collapse, and
935
hence we did not get strong downstream performance.
936
937
We would like to thank [Xinlei Chen](http://xinleic.xyz/)
938
(one of the authors of MAE) for helpful discussions. We are grateful to
939
[JarvisLabs](https://jarvislabs.ai/) and
940
[Google Developers Experts](https://developers.google.com/programs/experts/)
941
program for helping with GPU credits.
942
"""