Knowledge distillation recipes
Author: Sayak Paul
Date created: 2021/08/01
Last modified: 2021/08/01
Description: Training better student models via knowledge distillation with function matching.
Introduction
Knowledge distillation (Hinton et al.) is a technique that enables us to compress larger models into smaller ones. This allows us to reap the benefits of high-performing larger models, while reducing storage and memory costs and achieving higher inference speed:
Smaller models -> smaller memory footprint
Reduced complexity -> fewer floating-point operations (FLOPs)
In Knowledge distillation: A good teacher is patient and consistent, Beyer et al. investigate various existing setups for performing knowledge distillation and show that all of them lead to sub-optimal performance. Due to this, practitioners often settle for other alternatives (quantization, pruning, weight clustering, etc.) when developing production systems that are resource-constrained.
Beyer et al. investigate how we can improve the student models that come out of the knowledge distillation process and always match the performance of their teacher models. In this example, we will study the recipes introduced by them, using the Flowers102 dataset. As a reference, with these recipes, the authors were able to produce a ResNet50 model that achieves 82.8% accuracy on the ImageNet-1k dataset.
In case you need a refresher on knowledge distillation and want to study how it is implemented in Keras, you can refer to this example. You can also follow this example that shows an extension of knowledge distillation applied to consistency training.
To follow this example, you will need TensorFlow 2.5 or higher as well as TensorFlow Addons, which can be installed using the command below:
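A typical way to do this in a notebook cell (the exact version pins used originally may differ) is:

```python
# Install TensorFlow Addons; TensorFlow itself ships with recent Colab/Kaggle images.
!pip install -q tensorflow-addons
```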
Imports
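A minimal set of imports for this example would look like the sketch below; it assumes TensorFlow Datasets is used to load Flowers102 and matplotlib is used for the visualizations later on:

```python
from tensorflow import keras
import tensorflow_addons as tfa
import tensorflow as tf

import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
```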
Hyperparameters and constants
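The sketch below shows the kind of constants used throughout the rest of this example. The concrete values (batch size, resolutions, epoch count) are illustrative and can be tuned to your hardware:

```python
AUTO = tf.data.AUTOTUNE  # Let tf.data tune pipeline parallelism automatically.

BATCH_SIZE = 64
EPOCHS = 30  # The paper recommends 1000+ epochs; 30 keeps this run short.

# Train on random crops of size RESIZE taken from images resized to BIGGER.
BIGGER = 160
RESIZE = 128
```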
Load the Flowers102 dataset
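A sketch of loading the dataset with TensorFlow Datasets, assuming the oxford_flowers102 entry in the TFDS catalog:

```python
# Load the three official splits as (image, label) pairs.
train_ds, validation_ds, test_ds = tfds.load(
    "oxford_flowers102",
    split=["train", "validation", "test"],
    as_supervised=True,
)

print(f"Number of training examples: {train_ds.cardinality()}.")
print(f"Number of validation examples: {validation_ds.cardinality()}.")
print(f"Number of test examples: {test_ds.cardinality()}.")
```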
Teacher model
As is common with any distillation technique, it's important to first train a well-performing teacher model which is usually larger than the subsequent student model. The authors distill a BiT ResNet152x2 model (teacher) into a BiT ResNet50 model (student).
BiT stands for Big Transfer and was introduced in Big Transfer (BiT): General Visual Representation Learning. BiT variants of ResNets use Group Normalization (Wu et al.) and Weight Standardization (Qiao et al.) in place of Batch Normalization (Ioffe et al.). In order to limit the time it takes to run this example, we will be using a BiT ResNet101x3 already trained on the Flowers102 dataset. You can refer to this notebook to learn more about the training process. This model reaches 98.18% accuracy on the test set of Flowers102.
The model weights are hosted on Kaggle as a dataset. To download the weights, follow these steps:
Create an account on Kaggle here.
Go to the "Account" tab of your user profile.
Select "Create API Token". This will trigger the download of
kaggle.json
, a file containing your API credentials.From that JSON file, copy your Kaggle username and API key.
Now run the following:
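One way to do this from within the notebook is to export the credentials as environment variables (a sketch; fill in the values from your kaggle.json):

```python
import os

# Credentials copied from kaggle.json.
os.environ["KAGGLE_USERNAME"] = ""  # TODO: your Kaggle username
os.environ["KAGGLE_KEY"] = ""  # TODO: your Kaggle API key
```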
Once the environment variables are set, run:
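As a sketch, with <owner>/<teacher-weights-dataset> standing in for the actual Kaggle dataset slug that hosts the teacher weights:

```python
# Download and extract the teacher weights (replace the placeholder slug).
!kaggle datasets download -d <owner>/<teacher-weights-dataset>
!unzip -qq <teacher-weights-dataset>.zip
```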
This should generate a folder named T-r101x3-128, which is essentially a teacher SavedModel.
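Since the folder is a SavedModel, the teacher can be loaded directly with Keras (a sketch; adjust the path if you extracted the weights elsewhere):

```python
# Load the teacher from the extracted SavedModel directory.
teacher_model = keras.models.load_model("T-r101x3-128")
teacher_model.summary()
```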
The "function matching" recipe
To train a high-quality student model, the authors propose the following changes to the student training workflow:
Use an aggressive variant of MixUp (Zhang et al.). This is done by sampling the alpha parameter from a uniform distribution instead of a beta distribution. MixUp is used here to help the student model capture the function underlying the teacher model: since MixUp linearly interpolates between different samples across the data manifold, the rationale is that if the student is trained to fit these interpolated samples, it should match the teacher model better. To incorporate more invariance, MixUp is coupled with "Inception-style" cropping (Szegedy et al.). This is where the "function matching" term in the original paper comes from.
Unlike other works (Noisy Student Training, for example), both the teacher and the student models receive the same copy of an image, which is mixed up and randomly cropped. By providing the same inputs to both models, the authors make the teacher consistent with the student.
With MixUp, we are essentially introducing a strong form of regularization when training the student. As such, it should be trained for a relatively long period of time (at least 1000 epochs). Since the student is trained with strong regularization, the risk of overfitting due to a longer training schedule is also mitigated.
In summary, one needs to be consistent and patient while training the student model.
Data input pipeline
Note that for brevity, we used mild crops for the training set but in practice "Inception-style" preprocessing should be applied. You can refer to this script for a closer implementation. Also, the ground-truth labels are not used for training the student.
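Below is a sketch of such a pipeline using the constants defined earlier: images are resized, mildly cropped and flipped, and then MixUp is applied with the mixing coefficient drawn from a uniform distribution rather than a beta distribution, as described above. The function names are illustrative.

```python
def preprocess_image(image, label):
    # Resize, then take a mild random crop and a random horizontal flip.
    # The ground-truth label is dropped: it is not used to train the student.
    image = tf.image.resize(tf.cast(image, tf.float32) / 255.0, (BIGGER, BIGGER))
    image = tf.image.random_crop(image, (RESIZE, RESIZE, 3))
    image = tf.image.random_flip_left_right(image)
    return image


def mix_up(images_one, images_two):
    # Sample the mixing coefficient from a *uniform* distribution, as in the
    # "function matching" recipe, instead of the usual beta distribution.
    batch_size = tf.shape(images_one)[0]
    lam = tf.random.uniform((batch_size, 1, 1, 1), 0.0, 1.0)
    return lam * images_one + (1.0 - lam) * images_two


train_ds_one = (
    train_ds.shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    train_ds.shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
)
# Zip two shuffled views of the training set and mix them up pairwise.
mixed_train_ds = (
    tf.data.Dataset.zip((train_ds_one, train_ds_two))
    .map(mix_up, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)
```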
Visualization
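A quick sanity check on the mixed-up, cropped images (a small matplotlib sketch):

```python
# Take one batch of mixed-up images and plot a 5x5 grid.
sample_images = next(iter(mixed_train_ds))

plt.figure(figsize=(10, 10))
for n in range(25):
    plt.subplot(5, 5, n + 1)
    plt.imshow(sample_images[n].numpy())
    plt.axis("off")
plt.show()
```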
Student model
For the purpose of this example, we will use the standard ResNet50V2 (He et al.).
Compared to the teacher model, this model has 358 million fewer parameters.
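A sketch of building the student from the stock Keras application; the helper name get_resnetv2 and the linear classifier head are illustrative choices:

```python
def get_resnetv2():
    # Randomly initialized ResNet50V2 with a 102-way linear classification head.
    resnet_v2 = keras.applications.ResNet50V2(
        weights=None,
        input_shape=(RESIZE, RESIZE, 3),
        classes=102,
        classifier_activation="linear",
    )
    return resnet_v2


get_resnetv2().count_params()
```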
Distillation utility
We will reuse some code from this example on knowledge distillation.
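The sketch below shows one way to write such a utility: a tf.keras.Model subclass whose train_step matches the student's softmax outputs to the teacher's (here with a KL-divergence loss, as an assumption, and without temperature scaling) and whose test_step evaluates the student on labeled data. Note that the ground-truth labels never enter train_step, consistent with the pipeline above.

```python
class Distiller(tf.keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.loss_tracker = keras.metrics.Mean(name="distillation_loss")

    def compile(self, optimizer, metrics, distillation_loss_fn):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # The training pipeline only yields mixed-up images, no labels.
        images = data
        teacher_preds = tf.nn.softmax(self.teacher(images, training=False), axis=-1)

        with tf.GradientTape() as tape:
            student_preds = tf.nn.softmax(self.student(images, training=True), axis=-1)
            loss = self.distillation_loss_fn(teacher_preds, student_preds)

        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))

        self.loss_tracker.update_state(loss)
        return {"distillation_loss": self.loss_tracker.result()}

    def test_step(self, data):
        # Evaluate the student on labeled (image, label) pairs.
        images, labels = data
        predictions = self.student(images, training=False)
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}
```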
Learning rate schedule
A warmup cosine learning rate schedule is used in the paper. This schedule is also typical for many pre-training methods, especially in computer vision.
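As a minimal sketch (the class name and arguments are illustrative), such a schedule can be implemented as a LearningRateSchedule subclass that ramps the learning rate up linearly for a number of warmup steps and then decays it with a cosine curve:

```python
import math


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, total_steps, warmup_steps):
        super().__init__()
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Cosine decay over the post-warmup portion of training.
        progress = (step - self.warmup_steps) / float(
            self.total_steps - self.warmup_steps
        )
        cosine_lr = 0.5 * self.base_lr * (1.0 + tf.cos(math.pi * progress))
        # Linear warmup from 0 to base_lr over the first warmup_steps steps.
        warmup_lr = self.base_lr * step / float(self.warmup_steps)
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)
```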
We can now plot a graph of learning rates generated using this schedule.
The original paper uses at least 1000 epochs and a batch size of 512 to perform "function matching". The objective of this example is to present a workflow to implement the recipe and not to demonstrate the results when they are applied at full scale. However, these recipes will transfer to the original settings from the paper. Please refer to this repository if you are interested in finding out more.
Training
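A sketch of tying everything together: build a validation pipeline with labels kept, instantiate the distiller, compile it with an AdamW optimizer from TensorFlow Addons driven by the warmup cosine schedule, and fit on the mixed-up images. The concrete values (base learning rate, warmup fraction, weight decay) are illustrative:

```python
def preprocess_test(image, label):
    # Deterministic resize for evaluation; labels are kept for the metrics.
    image = tf.image.resize(tf.cast(image, tf.float32) / 255.0, (RESIZE, RESIZE))
    return image, label


validation_pipeline = (
    validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

total_steps = int(train_ds.cardinality().numpy() / BATCH_SIZE * EPOCHS)
scheduled_lr = WarmUpCosine(
    base_lr=1e-3, total_steps=total_steps, warmup_steps=int(0.1 * total_steps)
)

# AdamW from TensorFlow Addons provides Adam with decoupled weight decay.
optimizer = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=scheduled_lr)

distiller = Distiller(student=get_resnetv2(), teacher=teacher_model)
distiller.compile(
    optimizer=optimizer,
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    distillation_loss_fn=keras.losses.KLDivergence(),
)

history = distiller.fit(
    mixed_train_ds,
    validation_data=validation_pipeline,
    epochs=EPOCHS,
)
```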
Results
With just 30 epochs of training, the results are nowhere near the expected ones. This is where the benefits of patience, a.k.a. a longer training schedule, come into play. Let's investigate what the model trained for 1000 epochs can do.
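Assuming the fully trained (1000-epoch) student has been saved or downloaded somewhere, it can be loaded and evaluated as sketched below; the path is a placeholder, not the actual location of the weights:

```python
# Placeholder path; substitute the location of the 1000-epoch student model.
trained_student = keras.models.load_model("path/to/trained-student-model")
trained_student.summary()

test_pipeline = (
    test_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Manual top-1 accuracy computation, independent of how the model was compiled.
accuracy = keras.metrics.SparseCategoricalAccuracy()
for images, labels in test_pipeline:
    accuracy.update_state(labels, trained_student(images, training=False))
print(f"Top-1 accuracy on the test set: {accuracy.result() * 100:.2f}%")
```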
This model exactly follows what the authors have used in their student models. This is why the model summary is a bit different.
With 1000 epochs of training, this same model leads to a top-1 accuracy of 95.54%.
There are a number of important ablation studies presented in the paper that show the effectiveness of these recipes compared to the prior art. So if you are skeptical about these recipes, definitely consult the paper.
Note on training for longer
With TPU-based hardware infrastructure, we can train the model for 1000 epochs much faster. This does not even require many changes to this codebase. You are encouraged to check this repository, as it presents TPU-compatible training workflows for these recipes and can be run on Kaggle Kernels leveraging their free TPU v3-8 hardware.