"""1Title: Knowledge distillation recipes2Author: [Sayak Paul](https://twitter.com/RisingSayak)3Date created: 2021/08/014Last modified: 2021/08/015Description: Training better student models via knowledge distillation with function matching.6Accelerator: GPU7"""89"""10## Introduction1112Knowledge distillation ([Hinton et al.](https://arxiv.org/abs/1503.02531)) is a technique13that enables us to compress larger models into smaller ones. This allows us to reap the14benefits of high performing larger models, while reducing storage and memory costs and15achieving higher inference speed:1617* Smaller models -> smaller memory footprint18* Reduced complexity -> fewer floating-point operations (FLOPs)1920In [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237),21Beyer et al. investigate various existing setups for performing knowledge distillation22and show that all of them lead to sub-optimal performance. Due to this,23practitioners often settle for other alternatives (quantization, pruning, weight24clustering, etc.) when developing production systems that are resource-constrained.2526Beyer et al. investigate how we can improve the student models that come out27of the knowledge distillation process and always match the performance of28their teacher models. In this example, we will study the recipes introduced by them, using29the [Flowers102 dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). As a30reference, with these recipes, the authors were able to produce a ResNet50 model that31achieves 82.8% accuracy on the ImageNet-1k dataset.3233In case you need a refresher on knowledge distillation and want to study how it is34implemented in Keras, you can refer to35[this example](https://keras.io/examples/vision/knowledge_distillation/).36You can also follow37[this example](https://keras.io/examples/vision/consistency_training/)38that shows an extension of knowledge distillation applied to consistency training.39"""4041"""42## Imports43"""4445import os4647os.environ["KERAS_BACKEND"] = "tensorflow"4849import keras50import tensorflow as tf5152import matplotlib.pyplot as plt53import numpy as np5455import tensorflow_datasets as tfds5657tfds.disable_progress_bar()5859"""60## Hyperparameters and constants61"""6263AUTO = tf.data.AUTOTUNE # Used to dynamically adjust parallelism.64BATCH_SIZE = 646566# Comes from Table 4 and "Training setup" section.67TEMPERATURE = 10 # Used to soften the logits before they go to softmax.68INIT_LR = 0.003 # Initial learning rate that will be decayed over the training period.69WEIGHT_DECAY = 0.001 # Used for regularization.70CLIP_THRESHOLD = 1.0 # Used for clipping the gradients by L2-norm.7172# We will first resize the training images to a bigger size and then we will take73# random crops of a lower size.74BIGGER = 16075RESIZE = 1287677"""78## Load the Flowers102 dataset79"""8081train_ds, validation_ds, test_ds = tfds.load(82"oxford_flowers102", split=["train", "validation", "test"], as_supervised=True83)84print(f"Number of training examples: {train_ds.cardinality()}.")85print(f"Number of validation examples: {validation_ds.cardinality()}.")86print(f"Number of test examples: {test_ds.cardinality()}.")8788"""89## Teacher model9091As is common with any distillation technique, it's important to first train a92well-performing teacher model which is usually larger than the subsequent student model.93The authors distill a BiT ResNet152x2 model (teacher) into a BiT ResNet50 model94(student).9596BiT stands for Big Transfer and was 
"""
## Load the Flowers102 dataset
"""

train_ds, validation_ds, test_ds = tfds.load(
    "oxford_flowers102", split=["train", "validation", "test"], as_supervised=True
)
print(f"Number of training examples: {train_ds.cardinality()}.")
print(f"Number of validation examples: {validation_ds.cardinality()}.")
print(f"Number of test examples: {test_ds.cardinality()}.")

"""
## Teacher model

As is common with any distillation technique, it's important to first train a
well-performing teacher model, which is usually larger than the subsequent student model.
The authors distill a BiT ResNet152x2 model (teacher) into a BiT ResNet50 model
(student).

BiT stands for Big Transfer and was introduced in
[Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370).
BiT variants of ResNets use Group Normalization ([Wu et al.](https://arxiv.org/abs/1803.08494))
and Weight Standardization ([Qiao et al.](https://arxiv.org/abs/1903.10520v2))
in place of Batch Normalization ([Ioffe et al.](https://arxiv.org/abs/1502.03167)).
In order to limit the time it takes to run this example, we will use a BiT
ResNet101x3 model already trained on the Flowers102 dataset. You can refer to
[this notebook](https://github.com/sayakpaul/FunMatch-Distillation/blob/main/train_bit.ipynb)
to learn more about the training process. This model reaches 98.18% accuracy on the
test set of Flowers102.

The model weights are hosted on Kaggle as a dataset.
To download the weights, follow these steps:

1. Create an account on Kaggle [here](https://www.kaggle.com).
2. Go to the "Account" tab of your [user profile](https://www.kaggle.com/account).
3. Select "Create API Token". This will trigger the download of `kaggle.json`, a file
containing your API credentials.
4. From that JSON file, copy your Kaggle username and API key.

Now run the following:

```python
import os

os.environ["KAGGLE_USERNAME"] = ""  # TODO: enter your Kaggle user name here
os.environ["KAGGLE_KEY"] = ""  # TODO: enter your Kaggle key here
```

Once the environment variables are set, run:

```shell
$ kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
$ unzip -qq bitresnet101x3flowers102.zip
```

This should generate a folder named `T-r101x3-128`, which is essentially the teacher
[`SavedModel`](https://www.tensorflow.org/guide/saved_model).
"""

os.environ["KAGGLE_USERNAME"] = ""  # TODO: enter your Kaggle user name here
os.environ["KAGGLE_KEY"] = ""  # TODO: enter your Kaggle API key here

"""shell
!kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
"""

"""shell
!unzip -qq bitresnet101x3flowers102.zip
"""

# Since the teacher model is not going to be trained further, we make
# it non-trainable.
teacher_model = keras.layers.TFSMLayer(
    "/home/jupyter/keras-io/examples/keras_recipes/T-r101x3-128"
)
teacher_model.trainable = False
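"""
Optionally, you can sanity-check the teacher by running a dummy batch through it. The
exact output structure depends on how the `SavedModel` was exported; the snippet below
assumes the serving signature returns a single tensor of 102 class logits:

```python
dummy_batch = tf.random.uniform((2, RESIZE, RESIZE, 3))  # Two fake RGB images in [0, 1].
dummy_predictions = teacher_model(dummy_batch)
print(dummy_predictions.shape)  # Expected: (2, 102), if the assumption above holds.
```
"""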
"""
## The "function matching" recipe

To train a high-quality student model, the authors propose the following changes to the
student training workflow:

* Use an aggressive variant of MixUp ([Zhang et al.](https://arxiv.org/abs/1710.09412)).
This is done by sampling the `alpha` parameter from a uniform distribution instead of a
beta distribution. MixUp is used here to help the student model capture the function
underlying the teacher model. MixUp linearly interpolates between different samples
across the data manifold, so the rationale is that if the student is trained to fit
these interpolated samples, it should be better able to match the teacher model. To
incorporate more invariance, MixUp is coupled with "Inception-style" cropping
([Szegedy et al.](https://arxiv.org/abs/1409.4842)). This is where the
"function matching" term in the [original paper](https://arxiv.org/abs/2106.05237)
comes from.
* Unlike other works ([Noisy Student Training](https://arxiv.org/abs/1911.04252) for
example), both the teacher and student models receive the same copy of an image, which is
mixed up and randomly cropped. By providing the same inputs to both models, the authors
make the teacher consistent with the student.
* With MixUp, we are essentially introducing a strong form of regularization when
training the student. As such, it should be trained for a relatively long period of time
(at least 1000 epochs). Since the student is trained with strong regularization, the
risk of overfitting due to a longer training schedule is also mitigated.

In summary, one needs to be consistent and patient while training the student model.
"""

"""
## Data input pipeline
"""


def mixup(images, labels):
    alpha = tf.random.uniform([], 0, 1)
    mixedup_images = alpha * images + (1 - alpha) * tf.reverse(images, axis=[0])
    # The labels do not matter here since they are NOT used during
    # training.
    return mixedup_images, labels


def preprocess_image(image, label, train=True):
    image = tf.cast(image, tf.float32) / 255.0

    if train:
        image = tf.image.resize(image, (BIGGER, BIGGER))
        image = tf.image.random_crop(image, (RESIZE, RESIZE, 3))
        image = tf.image.random_flip_left_right(image)
    else:
        # Central fraction amount is from here:
        # https://git.io/J8Kda.
        image = tf.image.central_crop(image, central_fraction=0.875)
        image = tf.image.resize(image, (RESIZE, RESIZE))

    return image, label


def prepare_dataset(dataset, train=True, batch_size=BATCH_SIZE):
    if train:
        dataset = dataset.map(preprocess_image, num_parallel_calls=AUTO)
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    else:
        dataset = dataset.map(
            lambda x, y: (preprocess_image(x, y, train)), num_parallel_calls=AUTO
        )
    dataset = dataset.batch(batch_size)

    if train:
        dataset = dataset.map(mixup, num_parallel_calls=AUTO)

    dataset = dataset.prefetch(AUTO)
    return dataset


"""
Note that for brevity, we used mild crops for the training set, but in practice,
"Inception-style" preprocessing should be applied. You can refer to
[this script](https://github.com/sayakpaul/FunMatch-Distillation/blob/main/crop_resize.py)
for a closer implementation. Also, _**the ground-truth labels are not used for
training the student.**_
"""
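"""
For reference, the core idea of an "Inception-style" random resized crop can be sketched
with `tf.image.sample_distorted_bounding_box` as below. The `area_range` and
`aspect_ratio_range` values are common defaults, not necessarily the exact ones used by
the authors (see the linked script for those):

```python
def inception_style_crop(image):
    # Sample a random crop window covering 8%-100% of the image area with a
    # moderately distorted aspect ratio, then resize it to the target resolution.
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=tf.zeros((0, 0, 4), dtype=tf.float32),  # No boxes: crop freely.
        min_object_covered=0.0,
        aspect_ratio_range=(3 / 4, 4 / 3),
        area_range=(0.08, 1.0),
        use_image_if_no_bounding_boxes=True,
    )
    cropped = tf.slice(image, begin, size)
    return tf.image.resize(cropped, (RESIZE, RESIZE))
```
"""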
train_ds = prepare_dataset(train_ds, True)
validation_ds = prepare_dataset(validation_ds, False)
test_ds = prepare_dataset(test_ds, False)

"""
## Visualization
"""

sample_images, _ = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(sample_images[n].numpy())
    plt.axis("off")
plt.show()

"""
## Student model

For the purpose of this example, we will use the standard ResNet50V2
([He et al.](https://arxiv.org/abs/1603.05027)).
"""


def get_resnetv2():
    resnet_v2 = keras.applications.ResNet50V2(
        weights=None,
        input_shape=(RESIZE, RESIZE, 3),
        classes=102,
        classifier_activation="linear",
    )
    return resnet_v2


get_resnetv2().count_params()

"""
Compared to the teacher model, this model has 358 million fewer parameters.
"""

"""
## Distillation utility

We will reuse some code from
[this example](https://keras.io/examples/vision/knowledge_distillation/)
on knowledge distillation.
"""


class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        distillation_loss_fn,
        temperature=TEMPERATURE,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data. The labels are discarded.
        x, _ = data

        # Forward pass of the teacher.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student.
            student_predictions = self.student(x, training=True)

            # Compute the loss between the temperature-softened distributions.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(distillation_loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Report progress.
        self.loss_tracker.update_state(distillation_loss)
        return {"distillation_loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Unpack data.
        x, y = data

        # Forward passes.
        teacher_predictions = self.teacher(x, training=False)
        student_predictions = self.student(x, training=False)

        # Calculate the loss.
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
            tf.nn.softmax(student_predictions / self.temperature, axis=1),
        )

        # Report progress.
        self.loss_tracker.update_state(distillation_loss)
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        return results
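"""
To see the objective in isolation, here is a tiny illustration with made-up, already
softened probability distributions. The distillation loss is the KL divergence between
the teacher's and the student's distributions, and it approaches zero as the student
matches the teacher:

```python
kld = keras.losses.KLDivergence()

teacher_probs = tf.constant([[0.7, 0.2, 0.1]])  # Hypothetical softened teacher output.
student_probs = tf.constant([[0.5, 0.3, 0.2]])  # Hypothetical softened student output.

print(float(kld(teacher_probs, student_probs)))  # Positive: the distributions disagree.
print(float(kld(teacher_probs, teacher_probs)))  # Essentially zero: perfect matching.
```
"""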
"""
## Learning rate schedule

A warmup cosine learning rate schedule is used in the paper. This schedule is also
typical for many pre-training methods, especially in computer vision.
"""

# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
    ):
        super().__init__()

        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                "total_steps must be greater than or equal to warmup_steps."
            )

        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "learning_rate_base must be greater than or equal to "
                    "warmup_learning_rate."
                )
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


"""
We can now plot a graph of the learning rates generated by this schedule.
"""

ARTIFICIAL_EPOCHS = 1000
ARTIFICIAL_BATCH_SIZE = 512
DATASET_NUM_TRAIN_EXAMPLES = 1020
TOTAL_STEPS = int(
    DATASET_NUM_TRAIN_EXAMPLES / ARTIFICIAL_BATCH_SIZE * ARTIFICIAL_EPOCHS
)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=INIT_LR,
    total_steps=TOTAL_STEPS,
    warmup_learning_rate=0.0,
    warmup_steps=1500,
)

lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]
plt.plot(lrs)
plt.xlabel("Step", fontsize=14)
plt.ylabel("LR", fontsize=14)
plt.show()


"""
The original paper uses at least 1000 epochs and a batch size of 512 to perform
"function matching". The objective of this example is to present a workflow to
implement the recipe, not to demonstrate the results of applying it at full scale.
However, these recipes will transfer to the original settings from the paper. Please
refer to [this repository](https://github.com/sayakpaul/FunMatch-Distillation) if you are
interested in finding out more.
"""
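"""
If you do scale the recipe up, remember to recompute the schedule for the batch size and
number of epochs you actually train with. An illustrative sketch (the warmup length here
is an arbitrary choice, not a value from the paper):

```python
FULL_EPOCHS = 1000
FULL_BATCH_SIZE = 512
steps_per_epoch = int(np.ceil(DATASET_NUM_TRAIN_EXAMPLES / FULL_BATCH_SIZE))
full_total_steps = steps_per_epoch * FULL_EPOCHS

full_schedule = WarmUpCosine(
    learning_rate_base=INIT_LR,
    total_steps=full_total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=int(0.05 * full_total_steps),  # Assumed 5% warmup, not from the paper.
)
```
"""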
"""
## Training
"""

optimizer = keras.optimizers.AdamW(
    weight_decay=WEIGHT_DECAY, learning_rate=scheduled_lrs, clipnorm=CLIP_THRESHOLD
)

student_model = get_resnetv2()

distiller = Distiller(student=student_model, teacher=teacher_model)
distiller.compile(
    optimizer,
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    distillation_loss_fn=keras.losses.KLDivergence(),
    temperature=TEMPERATURE,
)

history = distiller.fit(
    train_ds,
    steps_per_epoch=int(np.ceil(DATASET_NUM_TRAIN_EXAMPLES / BATCH_SIZE)),
    validation_data=validation_ds,
    epochs=30,  # This should be at least 1000.
)

student = distiller.student
student.compile(metrics=["accuracy"])
_, top1_accuracy = student.evaluate(test_ds)
print(f"Top-1 accuracy on the test set: {round(top1_accuracy * 100, 2)}%")

"""
## Results

With just 30 epochs of training, the results are nowhere near the expected level.
This is where the benefit of patience, i.e., a longer training schedule, comes into
play. Let's investigate what the model trained for 1000 epochs can do.
"""

"""shell
# Download the pre-trained weights.
!wget https://git.io/JBO3Y -O S-r50x1-128-1000.tar.gz
!tar xf S-r50x1-128-1000.tar.gz
"""

# Wrap the SavedModel layer in a `Sequential` model so that we can call `compile()` and
# `evaluate()` on it. This assumes the serving signature returns the class logits.
pretrained_student = keras.Sequential([keras.layers.TFSMLayer("S-r50x1-128-1000")])
pretrained_student.compile(metrics=["accuracy"])

"""
This model exactly follows what the authors used for their student models.
"""

_, top1_accuracy = pretrained_student.evaluate(test_ds)
print(f"Top-1 accuracy on the test set: {round(top1_accuracy * 100, 2)}%")

"""
With just 1000 epochs of training, this same model leads to a top-1 accuracy of 95.54%.

There are a number of important ablation studies presented in the paper that show the
effectiveness of these recipes compared to the prior art. So if you are skeptical about
these recipes, definitely consult the paper.
"""

"""
## Note on training for longer

With TPU-based hardware infrastructure, training the model for 1000 epochs becomes much
faster. This does not require many changes to this codebase either. You are encouraged
to check out
[this repository](https://github.com/sayakpaul/FunMatch-Distillation),
as it presents TPU-compatible training workflows for these recipes that can be run on
[Kaggle Kernels](https://www.kaggle.com/kernels), leveraging their free TPU v3-8 hardware.
"""
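"""
The linked repository contains the complete, tested TPU workflow. For orientation only,
the generic `tf.distribute` setup looks roughly like the following sketch (assuming a
Cloud TPU or Kaggle TPU runtime; the resolver arguments depend on your environment):

```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Model creation and compilation need to happen inside the strategy scope.
with strategy.scope():
    student_model = get_resnetv2()
    distiller = Distiller(student=student_model, teacher=teacher_model)
    distiller.compile(
        keras.optimizers.AdamW(
            weight_decay=WEIGHT_DECAY,
            learning_rate=scheduled_lrs,
            clipnorm=CLIP_THRESHOLD,
        ),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        distillation_loss_fn=keras.losses.KLDivergence(),
        temperature=TEMPERATURE,
    )
```
"""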