"""1Title: Masked image modeling with Autoencoders2Author: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/12/204Last modified: 2021/12/215Description: Implementing Masked Autoencoders for self-supervised pretraining.6Accelerator: GPU7"""89"""10## Introduction1112In deep learning, models with growing **capacity** and **capability** can easily overfit13on large datasets (ImageNet-1K). In the field of natural language processing, the14appetite for data has been **successfully addressed** by self-supervised pretraining.1516In the academic paper17[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)18by He et. al. the authors propose a simple yet effective method to pretrain large19vision models (here [ViT Huge](https://arxiv.org/abs/2010.11929)). Inspired from20the pretraining algorithm of BERT ([Devlin et al.](https://arxiv.org/abs/1810.04805)),21they mask patches of an image and, through an autoencoder predict the masked patches.22In the spirit of "masked language modeling", this pretraining task could be referred23to as "masked image modeling".2425In this example, we implement26[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)27with the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. After28pretraining a scaled down version of ViT, we also implement the linear evaluation29pipeline on CIFAR-10.303132This implementation covers (MAE refers to Masked Autoencoder):3334- The masking algorithm35- MAE encoder36- MAE decoder37- Evaluation with linear probing3839As a reference, we reuse some of the code presented in40[this example](https://keras.io/examples/vision/image_classification_with_vision_transformer/).4142"""4344"""45## Imports46"""47import os4849os.environ["KERAS_BACKEND"] = "tensorflow"5051import tensorflow as tf52import keras53from keras import layers5455import matplotlib.pyplot as plt56import numpy as np57import random5859# Setting seeds for reproducibility.60SEED = 4261keras.utils.set_random_seed(SEED)6263"""64## Hyperparameters for pretraining6566Please feel free to change the hyperparameters and check your results. The best way to67get an intuition about the architecture is to experiment with it. 
Our hyperparameters are heavily inspired by the design guidelines laid out by the authors
in [the original paper](https://arxiv.org/abs/2111.06377).
"""

# DATA
BUFFER_SIZE = 1024
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-4

# PRETRAINING
EPOCHS = 100

# AUGMENTATION
IMAGE_SIZE = 48  # We will resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
MASK_PROPORTION = 0.75  # We have found 75% masking to give us the best results.

# ENCODER and DECODER
LAYER_NORM_EPS = 1e-6
ENC_PROJECTION_DIM = 128
DEC_PROJECTION_DIM = 64
ENC_NUM_HEADS = 4
ENC_LAYERS = 6
DEC_NUM_HEADS = 4
DEC_LAYERS = (
    2  # The decoder is lightweight but should be reasonably deep for reconstruction.
)
ENC_TRANSFORMER_UNITS = [
    ENC_PROJECTION_DIM * 2,
    ENC_PROJECTION_DIM,
]  # Size of the transformer layers.
DEC_TRANSFORMER_UNITS = [
    DEC_PROJECTION_DIM * 2,
    DEC_PROJECTION_DIM,
]
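
"""
As a quick sanity check on these settings (a note added here, not part of the original
recipe): each 48x48 image yields `(48 // 6) ** 2 = 64` patches, of which 75% (48 patches)
are masked during pretraining, and only the remaining 16 are ever seen by the encoder.
"""

num_masked = int(MASK_PROPORTION * NUM_PATCHES)
print(f"Patches per image: {NUM_PATCHES}")
print(f"Masked patches per image: {num_masked}")
print(f"Visible patches per image: {NUM_PATCHES - num_masked}")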

"""
## Load and prepare the CIFAR-10 dataset
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf.data.Dataset.from_tensor_slices(x_test)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)

"""
## Data augmentation

In previous self-supervised pretraining methodologies (such as
[SimCLR](https://arxiv.org/abs/2002.05709)), we have noticed that the data augmentation
pipeline plays an important role. On the other hand, the authors of this paper point out
that Masked Autoencoders **do not** rely on augmentations. They propose a simple
augmentation pipeline of:

- Resizing
- Random cropping (fixed-size or random-size)
- Random horizontal flipping
"""


def get_train_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
            layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
            layers.RandomFlip("horizontal"),
        ],
        name="train_data_augmentation",
    )
    return model


def get_test_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        ],
        name="test_data_augmentation",
    )
    return model


"""
## A layer for extracting patches from images

This layer takes images as input and divides them into patches. The layer also includes
two utility methods:

- `show_patched_image` -- Takes a batch of images and their corresponding patches to plot
a random pair of image and patches.
- `reconstruct_from_patch` -- Takes a single instance of patches and stitches them
together into the original image.
"""


class Patches(layers.Layer):
    def __init__(self, patch_size=PATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

        # Assuming the image has three channels, each patch would be
        # of size (patch_size, patch_size, 3).
        self.resize = layers.Reshape((-1, patch_size * patch_size * 3))

    def call(self, images):
        # Create patches from the input images.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )

        # Reshape the patches to (batch, num_patches, patch_area) and return it.
        patches = self.resize(patches)
        return patches

    def show_patched_image(self, images, patches):
        # This is a utility function which accepts a batch of images and its
        # corresponding patches and helps visualize one image and its patches
        # side by side.
        idx = np.random.choice(patches.shape[0])
        print(f"Index selected: {idx}.")

        plt.figure(figsize=(4, 4))
        plt.imshow(keras.utils.array_to_img(images[idx]))
        plt.axis("off")
        plt.show()

        n = int(np.sqrt(patches.shape[1]))
        plt.figure(figsize=(4, 4))
        for i, patch in enumerate(patches[idx]):
            ax = plt.subplot(n, n, i + 1)
            patch_img = tf.reshape(patch, (self.patch_size, self.patch_size, 3))
            plt.imshow(keras.utils.img_to_array(patch_img))
            plt.axis("off")
        plt.show()

        # Return the index chosen to validate it outside the method.
        return idx

    # Taken from https://stackoverflow.com/a/58082878/10319735
    def reconstruct_from_patch(self, patch):
        # This utility function takes patches from a *single* image and
        # reconstructs it back into the image. This is useful for the train
        # monitor callback.
        num_patches = patch.shape[0]
        n = int(np.sqrt(num_patches))
        patch = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))
        rows = tf.split(patch, n, axis=0)
        rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
        reconstructed = tf.concat(rows, axis=0)
        return reconstructed


"""
Let's visualize the image patches.
"""

# Get a batch of images.
image_batch = next(iter(train_ds))

# Augment the images.
augmentation_model = get_train_augmentation_model()
augmented_images = augmentation_model(image_batch)

# Define the patch layer.
patch_layer = Patches()

# Get the patches from the batched images.
patches = patch_layer(images=augmented_images)

# Now pass the images and the corresponding patches
# to the `show_patched_image` method.
random_index = patch_layer.show_patched_image(images=augmented_images, patches=patches)

# Choose the same image and try reconstructing the patches
# into the original image.
image = patch_layer.reconstruct_from_patch(patches[random_index])
plt.imshow(image)
plt.axis("off")
plt.show()
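
"""
As an optional check (not part of the original example), patching and stitching should be
exact inverses here, so the reconstruction above is expected to be pixel-identical to the
augmented image it came from.
"""

assert np.allclose(image.numpy(), augmented_images[random_index].numpy())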

"""
## Patch encoding with masking

Quoting the paper

> Following ViT, we divide an image into regular non-overlapping patches. Then we sample
a subset of patches and mask (i.e., remove) the remaining ones. Our sampling strategy is
straightforward: we sample random patches without replacement, following a uniform
distribution. We simply refer to this as "random sampling".

This layer includes masking and encoding the patches.

The utility methods of the layer are:

- `get_random_indices` -- Provides the mask and unmask indices.
- `generate_masked_image` -- Takes patches and unmask indices, and results in a random
masked image. This is an essential utility method for our training monitor callback
(defined later).
"""


class PatchEncoder(layers.Layer):
    def __init__(
        self,
        patch_size=PATCH_SIZE,
        projection_dim=ENC_PROJECTION_DIM,
        mask_proportion=MASK_PROPORTION,
        downstream=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.mask_proportion = mask_proportion
        self.downstream = downstream

        # This is a trainable mask token initialized randomly from a normal
        # distribution.
        self.mask_token = tf.Variable(
            tf.random.normal([1, patch_size * patch_size * 3]), trainable=True
        )

    def build(self, input_shape):
        (_, self.num_patches, self.patch_area) = input_shape

        # Create the projection layer for the patches.
        self.projection = layers.Dense(units=self.projection_dim)

        # Create the positional embedding layer.
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=self.projection_dim
        )

        # Number of patches that will be masked.
        self.num_mask = int(self.mask_proportion * self.num_patches)

    def call(self, patches):
        # Get the positional embeddings.
        batch_size = tf.shape(patches)[0]
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
        pos_embeddings = tf.tile(
            pos_embeddings, [batch_size, 1, 1]
        )  # (B, num_patches, projection_dim)

        # Embed the patches.
        patch_embeddings = (
            self.projection(patches) + pos_embeddings
        )  # (B, num_patches, projection_dim)

        if self.downstream:
            return patch_embeddings
        else:
            mask_indices, unmask_indices = self.get_random_indices(batch_size)
            # The encoder input is the unmasked patch embeddings. Here we gather
            # all the patches that should be unmasked.
            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)

            # Get the unmasked and masked position embeddings. We will need them
            # for the decoder.
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )  # (B, mask_numbers, projection_dim)

            # Repeat the mask token `num_mask` times.
            # Mask tokens stand in for the masked patches of the image.
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )

            # Get the masked embeddings for the tokens.
            masked_embeddings = self.projection(mask_tokens) + masked_positions
            return (
                unmasked_embeddings,  # Input to the encoder.
                masked_embeddings,  # First part of input to the decoder.
                unmasked_positions,  # Added to the encoder outputs.
                mask_indices,  # The indices that were masked.
                unmask_indices,  # The indices that were unmasked.
            )

    def get_random_indices(self, batch_size):
        # Create random indices from a uniform distribution and then split
        # them into mask and unmask indices.
        rand_indices = tf.argsort(
            tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
        )
        mask_indices = rand_indices[:, : self.num_mask]
        unmask_indices = rand_indices[:, self.num_mask :]
        return mask_indices, unmask_indices

    def generate_masked_image(self, patches, unmask_indices):
        # Choose a random patch and its corresponding unmask index.
        idx = np.random.choice(patches.shape[0])
        patch = patches[idx]
        unmask_index = unmask_indices[idx]

        # Build a numpy array of the same shape as the patch.
        new_patch = np.zeros_like(patch)

        # Iterate over new_patch and plug in the unmasked patches.
        for i in range(unmask_index.shape[0]):
            new_patch[unmask_index[i]] = patch[unmask_index[i]]
        return new_patch, idx


"""
Let's see the masking process in action on a sample image.
"""

# Create the patch encoder layer.
patch_encoder = PatchEncoder()

# Get the embeddings and positions.
(
    unmasked_embeddings,
    masked_embeddings,
    unmasked_positions,
    mask_indices,
    unmask_indices,
) = patch_encoder(patches=patches)


# Show a masked patch image.
new_patch, random_index = patch_encoder.generate_masked_image(patches, unmask_indices)

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
img = patch_layer.reconstruct_from_patch(new_patch)
plt.imshow(keras.utils.array_to_img(img))
plt.axis("off")
plt.title("Masked")
plt.subplot(1, 2, 2)
img = augmented_images[random_index]
plt.imshow(keras.utils.array_to_img(img))
plt.axis("off")
plt.title("Original")
plt.show()
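
"""
Optionally (not part of the original example), we can check the shapes produced by the
patch encoder: with 75% masking, only 16 of the 64 patch embeddings are handed to the
encoder, while 48 mask tokens are prepared for the decoder.
"""

print(f"Unmasked embeddings: {unmasked_embeddings.shape}")  # (BATCH_SIZE, 16, 128)
print(f"Masked embeddings:   {masked_embeddings.shape}")  # (BATCH_SIZE, 48, 128)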

"""
## MLP

This serves as the fully connected feed-forward network of the transformer architecture.
"""


def mlp(x, dropout_rate, hidden_units):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


"""
## MAE encoder

The MAE encoder is ViT. The only point to note here is that the encoder output is
layer-normalized.
"""


def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
    inputs = layers.Input((None, ENC_PROJECTION_DIM))
    x = inputs

    for _ in range(num_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=ENC_TRANSFORMER_UNITS, dropout_rate=0.1)

        # Skip connection 2.
        x = layers.Add()([x3, x2])

    outputs = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
    return keras.Model(inputs, outputs, name="mae_encoder")


"""
## MAE decoder

The authors point out that they use an **asymmetric** autoencoder model. They use a
lightweight decoder that takes "<10% computation per token vs. the encoder". We do not
match the "<10% computation" figure exactly in our implementation, but we do use a
smaller decoder (both in terms of depth and projection dimension).
"""


def create_decoder(
    num_layers=DEC_LAYERS, num_heads=DEC_NUM_HEADS, image_size=IMAGE_SIZE
):
    inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
    x = layers.Dense(DEC_PROJECTION_DIM)(inputs)

    for _ in range(num_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=DEC_TRANSFORMER_UNITS, dropout_rate=0.1)

        # Skip connection 2.
        x = layers.Add()([x3, x2])

    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
    x = layers.Flatten()(x)
    pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
    outputs = layers.Reshape((image_size, image_size, 3))(pre_final)

    return keras.Model(inputs, outputs, name="mae_decoder")
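
"""
A quick shape check (optional, not part of the original example): this decoder consumes
all 64 token embeddings (visible tokens plus mask tokens) at the encoder width and emits
a full 48x48x3 image; the trainer below patchifies the decoder output again before
computing the loss on the masked patches.
"""

# Expected output shape: (1, 48, 48, 3).
print(create_decoder()(tf.zeros((1, NUM_PATCHES, ENC_PROJECTION_DIM))).shape)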

"""
## MAE trainer

This is the trainer module. We wrap the encoder and decoder inside a `keras.Model`
subclass. This allows us to customize what happens in the `model.fit()` loop.
"""


class MaskedAutoencoder(keras.Model):
    def __init__(
        self,
        train_augmentation_model,
        test_augmentation_model,
        patch_layer,
        patch_encoder,
        encoder,
        decoder,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.train_augmentation_model = train_augmentation_model
        self.test_augmentation_model = test_augmentation_model
        self.patch_layer = patch_layer
        self.patch_encoder = patch_encoder
        self.encoder = encoder
        self.decoder = decoder

    def calculate_loss(self, images, test=False):
        # Augment the input images.
        if test:
            augmented_images = self.test_augmentation_model(images)
        else:
            augmented_images = self.train_augmentation_model(images)

        # Patch the augmented images.
        patches = self.patch_layer(augmented_images)

        # Encode the patches.
        (
            unmasked_embeddings,
            masked_embeddings,
            unmasked_positions,
            mask_indices,
            unmask_indices,
        ) = self.patch_encoder(patches)

        # Pass the unmasked patches to the encoder.
        encoder_outputs = self.encoder(unmasked_embeddings)

        # Create the decoder inputs.
        encoder_outputs = encoder_outputs + unmasked_positions
        decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)

        # Decode the inputs.
        decoder_outputs = self.decoder(decoder_inputs)
        decoder_patches = self.patch_layer(decoder_outputs)

        loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
        loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)

        # Compute the total loss.
        total_loss = self.compute_loss(y=loss_patch, y_pred=loss_output)

        return total_loss, loss_patch, loss_output

    def train_step(self, images):
        with tf.GradientTape() as tape:
            total_loss, loss_patch, loss_output = self.calculate_loss(images)

        # Apply gradients.
        train_vars = [
            self.train_augmentation_model.trainable_variables,
            self.patch_layer.trainable_variables,
            self.patch_encoder.trainable_variables,
            self.encoder.trainable_variables,
            self.decoder.trainable_variables,
        ]
        grads = tape.gradient(total_loss, train_vars)
        tv_list = []
        for grad, var in zip(grads, train_vars):
            for g, v in zip(grad, var):
                tv_list.append((g, v))
        self.optimizer.apply_gradients(tv_list)

        # Report progress.
        results = {}
        for metric in self.metrics:
            metric.update_state(loss_patch, loss_output)
            results[metric.name] = metric.result()
        return results

    def test_step(self, images):
        total_loss, loss_patch, loss_output = self.calculate_loss(images, test=True)

        # Update the trackers.
        results = {}
        for metric in self.metrics:
            metric.update_state(loss_patch, loss_output)
            results[metric.name] = metric.result()
        return results


"""
## Model initialization
"""

train_augmentation_model = get_train_augmentation_model()
test_augmentation_model = get_test_augmentation_model()
patch_layer = Patches()
patch_encoder = PatchEncoder()
encoder = create_encoder()
decoder = create_decoder()

mae_model = MaskedAutoencoder(
    train_augmentation_model=train_augmentation_model,
    test_augmentation_model=test_augmentation_model,
    patch_layer=patch_layer,
    patch_encoder=patch_encoder,
    encoder=encoder,
    decoder=decoder,
)

"""
## Training callbacks
"""

"""
### Visualization callback
"""

# Taking a batch of test inputs to measure the model's progress.
test_images = next(iter(test_ds))


class TrainMonitor(keras.callbacks.Callback):
    def __init__(self, epoch_interval=None):
        super().__init__()
        self.epoch_interval = epoch_interval

    def on_epoch_end(self, epoch, logs=None):
        if self.epoch_interval and epoch % self.epoch_interval == 0:
            test_augmented_images = self.model.test_augmentation_model(test_images)
            test_patches = self.model.patch_layer(test_augmented_images)
            (
                test_unmasked_embeddings,
                test_masked_embeddings,
                test_unmasked_positions,
                test_mask_indices,
                test_unmask_indices,
            ) = self.model.patch_encoder(test_patches)
            test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
            test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
            test_decoder_inputs = tf.concat(
                [test_encoder_outputs, test_masked_embeddings], axis=1
            )
            test_decoder_outputs = self.model.decoder(test_decoder_inputs)

            # Show a masked patch image.
            test_masked_patch, idx = self.model.patch_encoder.generate_masked_image(
                test_patches, test_unmask_indices
            )
            print(f"\nIdx chosen: {idx}")
            original_image = test_augmented_images[idx]
            masked_image = self.model.patch_layer.reconstruct_from_patch(
                test_masked_patch
            )
            reconstructed_image = test_decoder_outputs[idx]

            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
            ax[0].imshow(original_image)
            ax[0].set_title(f"Original: {epoch:03d}")

            ax[1].imshow(masked_image)
            ax[1].set_title(f"Masked: {epoch:03d}")

            ax[2].imshow(reconstructed_image)
            ax[2].set_title(f"Reconstructed: {epoch:03d}")

            plt.show()
            plt.close()


"""
### Learning rate scheduler
"""

# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
    ):
        super().__init__()

        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                "total_steps must be larger than or equal to warmup_steps."
            )

        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "learning_rate_base must be larger than or equal to "
                    "warmup_learning_rate."
                )
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

lrs = [scheduled_lrs(step) for step in range(total_steps)]
plt.plot(lrs)
plt.xlabel("Step", fontsize=14)
plt.ylabel("LR", fontsize=14)
plt.show()
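
"""
An optional numeric spot check of the schedule (not in the original example): the
learning rate should start at 0, peak at `LEARNING_RATE` right after warmup, and decay
back to 0 by the end of pretraining.
"""

for step in [0, warmup_steps, total_steps]:
    print(f"Step {step:>6d}: learning rate = {float(scheduled_lrs(step)):.6f}")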

# Assemble the callbacks.
train_callbacks = [TrainMonitor(epoch_interval=5)]

"""
## Model compilation and training
"""

optimizer = keras.optimizers.AdamW(
    learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
)

# Compile and pretrain the model.
mae_model.compile(
    optimizer=optimizer, loss=keras.losses.MeanSquaredError(), metrics=["mae"]
)
history = mae_model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=train_callbacks,
)

# Measure its performance.
loss, mae = mae_model.evaluate(test_ds)
print(f"Loss: {loss:.2f}")
print(f"MAE: {mae:.2f}")

"""
## Evaluation with linear probing
"""

"""
### Extract the encoder model along with other layers
"""

# Extract the augmentation layers.
train_augmentation_model = mae_model.train_augmentation_model
test_augmentation_model = mae_model.test_augmentation_model

# Extract the patchers.
patch_layer = mae_model.patch_layer
patch_encoder = mae_model.patch_encoder
patch_encoder.downstream = True  # Switch the downstream flag to True.

# Extract the encoder.
encoder = mae_model.encoder

# Pack as a model.
downstream_model = keras.Sequential(
    [
        layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
        patch_layer,
        patch_encoder,
        encoder,
        layers.BatchNormalization(),  # Refer to A.1 (Linear probing).
        layers.GlobalAveragePooling1D(),
        layers.Dense(NUM_CLASSES, activation="softmax"),
    ],
    name="linear_probe_model",
)

# Only the final classification layer of the `downstream_model` should be trainable.
for layer in downstream_model.layers[:-1]:
    layer.trainable = False

downstream_model.summary()

"""
We are using average pooling to extract learned representations from the MAE encoder.
Another approach would be to use a learnable dummy token inside the encoder during
pretraining (resembling the [CLS] token). Representations would then be extracted from
that token during the downstream tasks.
"""
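
"""
That alternative can be sketched as follows (purely illustrative; `ClassTokenPooling` and
its `cls_token` weight are hypothetical names, not part of this example): prepend a
learnable token to the patch embeddings before the encoder, and read the image
representation off that token instead of average pooling.
"""


class ClassTokenPooling(layers.Layer):
    # Hypothetical layer sketch: prepends a learnable [CLS]-style token to the
    # patch embeddings. After the encoder, one would take `outputs[:, 0]` as the
    # image representation instead of using `GlobalAveragePooling1D()`.
    def build(self, input_shape):
        self.cls_token = self.add_weight(
            shape=(1, 1, input_shape[-1]),
            initializer="random_normal",
            trainable=True,
            name="cls_token",
        )

    def call(self, patch_embeddings):
        batch_size = tf.shape(patch_embeddings)[0]
        cls_tokens = tf.tile(self.cls_token, [batch_size, 1, 1])
        return tf.concat([cls_tokens, patch_embeddings], axis=1)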

"""
### Prepare datasets for linear probing
"""


def prepare_data(images, labels, is_train=True):
    if is_train:
        augmentation_model = train_augmentation_model
    else:
        augmentation_model = test_augmentation_model

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_train:
        dataset = dataset.shuffle(BUFFER_SIZE)

    dataset = dataset.batch(BATCH_SIZE).map(
        lambda x, y: (augmentation_model(x), y), num_parallel_calls=AUTO
    )
    return dataset.prefetch(AUTO)


train_ds = prepare_data(x_train, y_train)
val_ds = prepare_data(x_val, y_val, is_train=False)
test_ds = prepare_data(x_test, y_test, is_train=False)

"""
### Perform linear probing
"""

linear_probe_epochs = 50
linear_probe_lr = 0.1
warm_epoch_percentage = 0.1
steps = int((len(x_train) // BATCH_SIZE) * linear_probe_epochs)

warmup_steps = int(steps * warm_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=linear_probe_lr,
    total_steps=steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

optimizer = keras.optimizers.SGD(learning_rate=scheduled_lrs, momentum=0.9)
downstream_model.compile(
    optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
downstream_model.fit(train_ds, validation_data=val_ds, epochs=linear_probe_epochs)

loss, accuracy = downstream_model.evaluate(test_ds)
accuracy = round(accuracy * 100, 2)
print(f"Accuracy on the test set: {accuracy}%.")

"""
We believe that with a more sophisticated hyperparameter tuning process and longer
pretraining it is possible to improve this performance further. For comparison, we took
the encoder architecture and
[trained it from scratch](https://github.com/ariG23498/mae-scalable-vision-learners/blob/master/regular-classification.ipynb)
in a fully supervised manner. This gave us ~76% test top-1 accuracy. The authors of
MAE demonstrate strong performance on the ImageNet-1K dataset as well as on
other downstream tasks like object detection and semantic segmentation.
"""

"""
## Final notes

We refer interested readers to other examples on self-supervised learning present on
keras.io:

* [SimCLR](https://keras.io/examples/vision/semisupervised_simclr/)
* [NNCLR](https://keras.io/examples/vision/nnclr)
* [SimSiam](https://keras.io/examples/vision/simsiam)

This idea of using BERT-flavored pretraining in computer vision was also explored in
[Selfie](https://arxiv.org/abs/1906.02940), but it could not demonstrate strong results.
Another concurrent work that explores the idea of masked image modeling is
[SimMIM](https://arxiv.org/abs/2111.09886). Finally, as a fun fact, we, the authors of
this example, also explored the idea of ["reconstruction as a pretext task"](https://i.ibb.co/k5CpwDX/image.png)
in 2020, but we could not prevent the network from representation collapse, and
hence we did not get strong downstream performance.

We would like to thank [Xinlei Chen](http://xinleic.xyz/)
(one of the authors of MAE) for helpful discussions. We are grateful to
[JarvisLabs](https://jarvislabs.ai/) and the
[Google Developers Experts](https://developers.google.com/programs/experts/)
program for helping with GPU credits.
"""