Path: blob/master/examples/vision/consistency_training.py
3507 views
"""1Title: Consistency training with supervision2Author: [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/04/134Last modified: 2021/04/195Description: Training with consistency regularization for robustness against data distribution shifts.6Accelerator: GPU7"""89"""10Deep learning models excel in many image recognition tasks when the data is independent11and identically distributed (i.i.d.). However, they can suffer from performance12degradation caused by subtle distribution shifts in the input data (such as random13noise, contrast change, and blurring). So, naturally, there arises a question of14why. As discussed in [A Fourier Perspective on Model Robustness in Computer Vision](https://arxiv.org/pdf/1906.08988.pdf)),15there's no reason for deep learning models to be robust against such shifts. Standard16model training procedures (such as standard image classification training workflows)17*don't* enable a model to learn beyond what's fed to it in the form of training data.1819In this example, we will be training an image classification model enforcing a sense of20*consistency* inside it by doing the following:2122* Train a standard image classification model.23* Train an _equal or larger_ model on a noisy version of the dataset (augmented using24[RandAugment](https://arxiv.org/abs/1909.13719)).25* To do this, we will first obtain predictions of the previous model on the clean images26of the dataset.27* We will then use these predictions and train the second model to match these28predictions on the noisy variant of the same images. This is identical to the workflow of29[*Knowledge Distillation*](https://keras.io/examples/vision/knowledge_distillation/) but30since the student model is equal or larger in size this process is also referred to as31***Self-Training***.3233This overall training workflow finds its roots in works like34[FixMatch](https://arxiv.org/abs/2001.07685), [Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848),35and [Noisy Student Training](https://arxiv.org/abs/1911.04252). Since this training36process encourages a model yield consistent predictions for clean as well as noisy37images, it's often referred to as *consistency training* or *training with consistency38regularization*. Although the example focuses on using consistency training to enhance39the robustness of models to common corruptions this example can also serve a template40for performing _weakly supervised learning_.4142This example requires TensorFlow 2.4 or higher, as well as TensorFlow Hub and TensorFlow43Models, which can be installed using the following command:4445"""4647"""shell48pip install -q tf-models-official tensorflow-addons49"""5051"""52## Imports and setup53"""5455from official.vision.image_classification.augment import RandAugment56from tensorflow.keras import layers5758import tensorflow as tf59import tensorflow_addons as tfa60import matplotlib.pyplot as plt6162tf.random.set_seed(42)6364"""65## Define hyperparameters66"""6768AUTO = tf.data.AUTOTUNE69BATCH_SIZE = 12870EPOCHS = 57172CROP_TO = 7273RESIZE_TO = 967475"""76## Load the CIFAR-10 dataset77"""7879(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()8081val_samples = 4950082new_train_x, new_y_train = x_train[: val_samples + 1], y_train[: val_samples + 1]83val_x, val_y = x_train[val_samples:], y_train[val_samples:]8485"""86## Create TensorFlow `Dataset` objects87"""8889# Initialize `RandAugment` object with 2 layers of90# augmentation transforms and strength of 9.91augmenter = RandAugment(num_layers=2, magnitude=9)9293"""94For training the teacher model, we will only be using two geometric augmentation95transforms: random horizontal flip and random crop.96"""979899def preprocess_train(image, label, noisy=True):100image = tf.image.random_flip_left_right(image)101# We first resize the original image to a larger dimension102# and then we take random crops from it.103image = tf.image.resize(image, [RESIZE_TO, RESIZE_TO])104image = tf.image.random_crop(image, [CROP_TO, CROP_TO, 3])105if noisy:106image = augmenter.distort(image)107return image, label108109110def preprocess_test(image, label):111image = tf.image.resize(image, [CROP_TO, CROP_TO])112return image, label113114115train_ds = tf.data.Dataset.from_tensor_slices((new_train_x, new_y_train))116validation_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))117test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))118119"""120We make sure `train_clean_ds` and `train_noisy_ds` are shuffled using the *same* seed to121ensure their orders are exactly the same. This will be helpful during training the122student model.123"""124125# This dataset will be used to train the first model.126train_clean_ds = (127train_ds.shuffle(BATCH_SIZE * 10, seed=42)128.map(lambda x, y: (preprocess_train(x, y, noisy=False)), num_parallel_calls=AUTO)129.batch(BATCH_SIZE)130.prefetch(AUTO)131)132133# This prepares the `Dataset` object to use RandAugment.134train_noisy_ds = (135train_ds.shuffle(BATCH_SIZE * 10, seed=42)136.map(preprocess_train, num_parallel_calls=AUTO)137.batch(BATCH_SIZE)138.prefetch(AUTO)139)140141validation_ds = (142validation_ds.map(preprocess_test, num_parallel_calls=AUTO)143.batch(BATCH_SIZE)144.prefetch(AUTO)145)146147test_ds = (148test_ds.map(preprocess_test, num_parallel_calls=AUTO)149.batch(BATCH_SIZE)150.prefetch(AUTO)151)152153# This dataset will be used to train the second model.154consistency_training_ds = tf.data.Dataset.zip((train_clean_ds, train_noisy_ds))155156"""157## Visualize the datasets158"""159160sample_images, sample_labels = next(iter(train_clean_ds))161plt.figure(figsize=(10, 10))162for i, image in enumerate(sample_images[:9]):163ax = plt.subplot(3, 3, i + 1)164plt.imshow(image.numpy().astype("int"))165plt.axis("off")166167sample_images, sample_labels = next(iter(train_noisy_ds))168plt.figure(figsize=(10, 10))169for i, image in enumerate(sample_images[:9]):170ax = plt.subplot(3, 3, i + 1)171plt.imshow(image.numpy().astype("int"))172plt.axis("off")173174"""175## Define a model building utility function176177We now define our model building utility. Our model is based on the [ResNet50V2 architecture](https://arxiv.org/abs/1603.05027).178"""179180181def get_training_model(num_classes=10):182resnet50_v2 = tf.keras.applications.ResNet50V2(183weights=None,184include_top=False,185input_shape=(CROP_TO, CROP_TO, 3),186)187model = tf.keras.Sequential(188[189layers.Input((CROP_TO, CROP_TO, 3)),190layers.Rescaling(scale=1.0 / 127.5, offset=-1),191resnet50_v2,192layers.GlobalAveragePooling2D(),193layers.Dense(num_classes),194]195)196return model197198199"""200In the interest of reproducibility, we serialize the initial random weights of the201teacher network.202"""203204initial_teacher_model = get_training_model()205initial_teacher_model.save_weights("initial_teacher_model.h5")206207"""208## Train the teacher model209210As noted in Noisy Student Training, if the teacher model is trained with *geometric211ensembling* and when the student model is forced to mimic that, it leads to better212performance. The original work uses [Stochastic Depth](https://arxiv.org/abs/1603.09382)213and [Dropout](https://jmlr.org/papers/v15/srivastava14a.html) to bring in the ensembling214part but for this example, we will use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407)215(SWA) which also resembles geometric ensembling.216"""217218# Define the callbacks.219reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=3)220early_stopping = tf.keras.callbacks.EarlyStopping(221patience=10, restore_best_weights=True222)223224# Initialize SWA from tf-hub.225SWA = tfa.optimizers.SWA226227# Compile and train the teacher model.228teacher_model = get_training_model()229teacher_model.load_weights("initial_teacher_model.h5")230teacher_model.compile(231# Notice that we are wrapping our optimizer within SWA232optimizer=SWA(tf.keras.optimizers.Adam()),233loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),234metrics=["accuracy"],235)236history = teacher_model.fit(237train_clean_ds,238epochs=EPOCHS,239validation_data=validation_ds,240callbacks=[reduce_lr, early_stopping],241)242243# Evaluate the teacher model on the test set.244_, acc = teacher_model.evaluate(test_ds, verbose=0)245print(f"Test accuracy: {acc*100}%")246247"""248## Define a self-training utility249250For this part, we will borrow the `Distiller` class from [this Keras Example](https://keras.io/examples/vision/knowledge_distillation/).251"""252253254# Majority of the code is taken from:255# https://keras.io/examples/vision/knowledge_distillation/256class SelfTrainer(tf.keras.Model):257def __init__(self, student, teacher):258super().__init__()259self.student = student260self.teacher = teacher261262def compile(263self,264optimizer,265metrics,266student_loss_fn,267distillation_loss_fn,268temperature=3,269):270super().compile(optimizer=optimizer, metrics=metrics)271self.student_loss_fn = student_loss_fn272self.distillation_loss_fn = distillation_loss_fn273self.temperature = temperature274275def train_step(self, data):276# Since our dataset is a zip of two independent datasets,277# after initially parsing them, we segregate the278# respective images and labels next.279clean_ds, noisy_ds = data280clean_images, _ = clean_ds281noisy_images, y = noisy_ds282283# Forward pass of teacher284teacher_predictions = self.teacher(clean_images, training=False)285286with tf.GradientTape() as tape:287# Forward pass of student288student_predictions = self.student(noisy_images, training=True)289290# Compute losses291student_loss = self.student_loss_fn(y, student_predictions)292distillation_loss = self.distillation_loss_fn(293tf.nn.softmax(teacher_predictions / self.temperature, axis=1),294tf.nn.softmax(student_predictions / self.temperature, axis=1),295)296total_loss = (student_loss + distillation_loss) / 2297298# Compute gradients299trainable_vars = self.student.trainable_variables300gradients = tape.gradient(total_loss, trainable_vars)301302# Update weights303self.optimizer.apply_gradients(zip(gradients, trainable_vars))304305# Update the metrics configured in `compile()`306self.compiled_metrics.update_state(307y, tf.nn.softmax(student_predictions, axis=1)308)309310# Return a dict of performance311results = {m.name: m.result() for m in self.metrics}312results.update({"total_loss": total_loss})313return results314315def test_step(self, data):316# During inference, we only pass a dataset consisting images and labels.317x, y = data318319# Compute predictions320y_prediction = self.student(x, training=False)321322# Update the metrics323self.compiled_metrics.update_state(y, tf.nn.softmax(y_prediction, axis=1))324325# Return a dict of performance326results = {m.name: m.result() for m in self.metrics}327return results328329330"""331The only difference in this implementation is the way loss is being calculated. **Instead332of weighted the distillation loss and student loss differently we are taking their333average following Noisy Student Training**.334"""335336"""337## Train the student model338"""339340# Define the callbacks.341# We are using a larger decay factor to stabilize the training.342reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(343patience=3, factor=0.5, monitor="val_accuracy"344)345early_stopping = tf.keras.callbacks.EarlyStopping(346patience=10, restore_best_weights=True, monitor="val_accuracy"347)348349# Compile and train the student model.350self_trainer = SelfTrainer(student=get_training_model(), teacher=teacher_model)351self_trainer.compile(352# Notice we are *not* using SWA here.353optimizer="adam",354metrics=["accuracy"],355student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),356distillation_loss_fn=tf.keras.losses.KLDivergence(),357temperature=10,358)359history = self_trainer.fit(360consistency_training_ds,361epochs=EPOCHS,362validation_data=validation_ds,363callbacks=[reduce_lr, early_stopping],364)365366# Evaluate the student model.367acc = self_trainer.evaluate(test_ds, verbose=0)368print(f"Test accuracy from student model: {acc*100}%")369370"""371## Assess the robustness of the models372373A standard benchmark of assessing the robustness of vision models is to record their374performance on corrupted datasets like ImageNet-C and CIFAR-10-C both of which were375proposed in [Benchmarking Neural Network Robustness to Common Corruptions and376Perturbations](https://arxiv.org/abs/1903.12261). For this example, we will be using the377CIFAR-10-C dataset which has 19 different corruptions on 5 different severity levels. To378assess the robustness of the models on this dataset, we will do the following:379380* Run the pre-trained models on the highest level of severities and obtain the top-1381accuracies.382* Compute the mean top-1 accuracy.383384For the purpose of this example, we won't be going through these steps. This is why we385trained the models for only 5 epochs. You can check out [this386repository](https://github.com/sayakpaul/Consistency-Training-with-Supervision) that387demonstrates the full-scale training experiments and also the aforementioned assessment.388The figure below presents an executive summary of that assessment:389390391392**Mean Top-1** results stand for the CIFAR-10-C dataset and **Test Top-1** results stand393for the CIFAR-10 test set. It's clear that consistency training has an advantage on not394only enhancing the model robustness but also on improving the standard test performance.395"""396397398