# Knowledge Distillation

**Author:** Kenneth Borup<br>
**Date created:** 2020/09/01<br>
**Last modified:** 2020/09/01<br>
**Description:** Implementation of classical Knowledge Distillation.
## Introduction to Knowledge Distillation
Knowledge Distillation is a procedure for model compression, in which a small (student) model is trained to match a large pre-trained (teacher) model. Knowledge is transferred from the teacher model to the student by minimizing a loss function, aimed at matching softened teacher logits as well as ground-truth labels.
The logits are softened by applying a "temperature" scaling function in the softmax, effectively smoothing out the probability distribution and revealing inter-class relationships learned by the teacher.
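To make the effect of the temperature concrete, here is a small standalone sketch (not part of the original example; the function name `softened_probs` and the example logits are illustrative):

```python
import numpy as np


def softened_probs(logits, temperature=1.0):
    """Softmax with temperature scaling: higher temperatures flatten the distribution."""
    scaled = np.asarray(logits, dtype="float64") / temperature
    exp = np.exp(scaled - scaled.max())  # subtract max for numerical stability
    return exp / exp.sum()


logits = [8.0, 2.0, 0.5]
print(softened_probs(logits, temperature=1.0))   # sharply peaked, roughly [0.997, 0.002, 0.001]
print(softened_probs(logits, temperature=10.0))  # much softer; secondary classes become visible
```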
**Reference:**

- [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531)
## Setup
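The original import cell is not preserved in this extract; a minimal setup along these lines should be enough to follow the rest of the example (the backend choice is an assumption, any Keras 3 backend works):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # assumed backend; "jax" or "torch" also work

import numpy as np

import keras
from keras import layers, ops
```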
## Construct `Distiller()` class

The custom `Distiller()` class overrides the `Model` methods `compile`, `compute_loss`, and `call`. In order to use the distiller, we need:
- A trained teacher model
- A student model to train
- A student loss function on the difference between student predictions and ground-truth
- A distillation loss function, along with a `temperature`, on the difference between the soft student predictions and the soft teacher labels
- An `alpha` factor to weight the student and distillation loss
- An optimizer for the student and (optional) metrics to evaluate performance
In the `compute_loss` method, we perform a forward pass of both the teacher and student, and compute the total loss as a weighted sum of the `student_loss` and the `distillation_loss`, weighted by `alpha` and `1 - alpha`, respectively. Note that only the student weights are updated. A sketch of such a class is shown below.
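The original class definition is not included in this extract; building on the imports from the setup above, a rough sketch of such a `Distiller` subclass (the `alpha=0.1` and `temperature=3` defaults are assumptions) might look like this:

```python
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student
        # Keep the teacher fixed; only the student weights are updated.
        self.teacher.trainable = False

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        # Configure the distiller: the optimizer and metrics apply to the student.
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, **kwargs):
        # Forward pass of the teacher in inference mode.
        teacher_pred = self.teacher(x, training=False)

        # Hard-label loss on the student predictions (y_pred is the student output).
        student_loss = self.student_loss_fn(y, y_pred)

        # Soft-label loss between temperature-scaled teacher and student distributions.
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        # Weight the student and distillation losses by alpha and 1 - alpha.
        return self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    def call(self, x):
        # The distiller's forward pass is simply the student's forward pass.
        return self.student(x)
```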
## Create student and teacher models

Initially, we create a teacher model and a smaller student model. Both models are convolutional neural networks created using `Sequential()`, but they could be any Keras model.
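The original model definitions are not shown in this extract; an illustrative pair (the layer widths, strides, and activations are assumptions) could look like:

```python
# Teacher: a relatively large convolutional network. Note the final Dense layer
# outputs logits (no softmax), which is required for temperature scaling.
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Student: a much smaller version of the teacher, also ending in logits.
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
```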
## Prepare the dataset

The dataset used for training the teacher and distilling it into the student is MNIST, and the procedure would be equivalent for any other dataset, e.g. CIFAR-10, with a suitable choice of models. Both the student and teacher are trained on the training set and evaluated on the test set.
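A minimal preparation step could look as follows (the normalization and reshaping follow standard MNIST handling and are assumptions about the original code):

```python
# Load the train and test splits of MNIST.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize pixel values to [0, 1] and add a channel dimension.
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
```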
## Train the teacher
In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start by training the teacher model on the training set in the usual way.
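A sketch of this step (the choice of optimizer and the number of epochs are assumptions):

```python
# Compile the teacher with plain cross-entropy on logits.
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the teacher, then evaluate on the test set; evaluate() returns [loss, accuracy].
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
```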
Each of the outputs below is a `[test loss, test accuracy]` pair returned by `evaluate()`. The first is the teacher, which reaches roughly 97.8% test accuracy; the remaining two appear to come from the student evaluations later in the full example.

`[0.09044107794761658, 0.978100061416626]`

`[0.017046602442860603, 0.969200074672699]`

`[0.0629437193274498, 0.9778000712394714]`