Path: blob/master/examples/vision/knowledge_distillation.py
"""1Title: Knowledge Distillation2Author: [Kenneth Borup](https://twitter.com/Kennethborup)3Date created: 2020/09/014Last modified: 2020/09/015Description: Implementation of classical Knowledge Distillation.6Accelerator: GPU7Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)8"""910"""11## Introduction to Knowledge Distillation1213Knowledge Distillation is a procedure for model14compression, in which a small (student) model is trained to match a large pre-trained15(teacher) model. Knowledge is transferred from the teacher model to the student16by minimizing a loss function, aimed at matching softened teacher logits as well as17ground-truth labels.1819The logits are softened by applying a "temperature" scaling function in the softmax,20effectively smoothing out the probability distribution and revealing21inter-class relationships learned by the teacher.2223**Reference:**2425- [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531)26"""2728"""29## Setup30"""3132import os3334import keras35from keras import layers36from keras import ops37import numpy as np3839"""40## Construct `Distiller()` class4142The custom `Distiller()` class, overrides the `Model` methods `compile`, `compute_loss`,43and `call`. In order to use the distiller, we need:4445- A trained teacher model46- A student model to train47- A student loss function on the difference between student predictions and ground-truth48- A distillation loss function, along with a `temperature`, on the difference between the49soft student predictions and the soft teacher labels50- An `alpha` factor to weight the student and distillation loss51- An optimizer for the student and (optional) metrics to evaluate performance5253In the `compute_loss` method, we perform a forward pass of both the teacher and student,54calculate the loss with weighting of the `student_loss` and `distillation_loss` by `alpha`55and `1 - alpha`, respectively. 


class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        # Forward pass of the teacher is only used as a soft target; its weights
        # are never updated.
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)


"""
## Create student and teacher models

Initially, we create a teacher model and a smaller student model. Both models are
convolutional neural networks, created using `Sequential()`,
but could be any Keras model.
"""

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
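
"""
Since knowledge distillation is framed as model compression, it can be helpful to compare
the sizes of the two models. As a rough sanity check (the exact counts depend on the
architectures defined above), we can print the number of parameters of each model:
"""

# The teacher should have substantially more parameters than the student.
print("Teacher parameters:", teacher.count_params())
print("Student parameters:", student.count_params())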

"""
## Prepare the dataset

The dataset used for training the teacher and distilling it into the student is
[MNIST](https://keras.io/api/datasets/mnist/), and the procedure would be equivalent for
any other dataset, e.g. [CIFAR-10](https://keras.io/api/datasets/cifar10/), with a
suitable choice of models. Both the student and teacher are trained on the training set
and evaluated on the test set.
"""

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))


"""
## Train the teacher

In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start
by training the teacher model on the training set in the usual way.
"""

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)

"""
## Distill teacher to student

We have already trained the teacher model, and we only need to initialize a
`Distiller(student, teacher)` instance, `compile()` it with the desired losses,
hyperparameters and optimizer, and distill the teacher to the student.
"""

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

"""
## Train student from scratch for comparison

We can also train an equivalent student model from scratch without the teacher, in order
to evaluate the performance gain obtained by knowledge distillation.
"""

# Train student as done usually
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

"""
If the teacher is trained for 5 full epochs and the student is distilled on this teacher
for 3 full epochs, you should in this example see a performance boost compared to
training the same student model from scratch, and even compared to the teacher itself.
You should expect the teacher to have accuracy around 97.6%, the student trained from
scratch to be around 97.6%, and the distilled student to be around 98.1%. Try out
different random seeds to see how the weight initialization affects these results.
"""
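
"""
One way to experiment with the weight initialization is to set a global random seed
before building the models. This is only a minimal sketch (the example above does not
set a seed, and the value below is arbitrary):
"""

# Call this before creating the teacher and student models so that their initial weights
# (and therefore the final accuracies) become reproducible; change or remove the seed to
# get different initializations.
keras.utils.set_random_seed(42)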