GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/knowledge_distillation.py
"""
Title: Knowledge Distillation
Author: [Kenneth Borup](https://twitter.com/Kennethborup)
Date created: 2020/09/01
Last modified: 2020/09/01
Description: Implementation of classical Knowledge Distillation.
Accelerator: GPU
Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)
"""

"""
## Introduction to Knowledge Distillation

Knowledge Distillation is a procedure for model
compression, in which a small (student) model is trained to match a large pre-trained
(teacher) model. Knowledge is transferred from the teacher model to the student
by minimizing a loss function, aimed at matching softened teacher logits as well as
ground-truth labels.

The logits are softened by applying a "temperature" scaling function in the softmax,
effectively smoothing out the probability distribution and revealing
inter-class relationships learned by the teacher.

**Reference:**

- [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531)
"""

"""
## Setup
"""

import os

import keras
from keras import layers
from keras import ops
import numpy as np
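
"""
Before building the distiller, it is worth seeing concretely what the "temperature" does.
The short sketch below is illustrative only (the logits are made up and not part of the
original example): dividing logits by a temperature greater than 1 before the softmax
flattens the resulting distribution, exposing the relative similarity of the
non-argmax classes.
"""

# Illustrative only: made-up logits for a single example with three classes.
example_logits = np.array([[8.0, 2.0, 1.0]], dtype="float32")

# Standard softmax (temperature = 1) is sharply peaked on the first class.
print(ops.softmax(example_logits, axis=1))

# Temperature-scaled softmax is much softer, revealing how the remaining
# classes are ranked relative to each other.
print(ops.softmax(example_logits / 10.0, axis=1))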

"""
## Construct `Distiller()` class

The custom `Distiller()` class overrides the `Model` methods `compile`, `compute_loss`,
and `call`. In order to use the distiller, we need:

- A trained teacher model
- A student model to train
- A student loss function on the difference between student predictions and ground-truth
- A distillation loss function, along with a `temperature`, on the difference between the
  soft student predictions and the soft teacher labels
- An `alpha` factor to weight the student and distillation loss
- An optimizer for the student and (optional) metrics to evaluate performance

In the `compute_loss` method, we perform a forward pass of both the teacher and student,
and calculate the loss by weighting the `student_loss` and `distillation_loss` by `alpha`
and `1 - alpha`, respectively. Note: only the student weights are updated.
"""


class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        # The teacher is only used for inference; its weights are never updated.
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        # Compare temperature-softened teacher and student distributions, scaled
        # by temperature**2 as in Hinton et al. (2015).
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        # Predictions (and evaluation) always go through the student.
        return self.student(x)


"""
## Create student and teacher models

Initially, we create a teacher model and a smaller student model. Both models are
convolutional neural networks and created using `Sequential()`,
but could be any Keras model.
"""

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
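
"""
As a quick sanity check (not part of the original example), we can compare the sizes of
the two networks. Both models are already built, since they start from a `keras.Input`,
so `count_params()` can be called directly; the student should have far fewer parameters
than the teacher.
"""

# Optional: confirm that the student is substantially smaller than the teacher.
print("Teacher parameters:", teacher.count_params())
print("Student parameters:", student.count_params())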

"""
## Prepare the dataset

The dataset used for training the teacher and distilling it to the student is
[MNIST](https://keras.io/api/datasets/mnist/), and the procedure would be equivalent for
any other dataset, e.g. [CIFAR-10](https://keras.io/api/datasets/cifar10/), with a
suitable choice of models. Both the student and teacher are trained on the training set
and evaluated on the test set.
"""

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
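
"""
A quick check of the resulting arrays (optional, not part of the original example): MNIST
gives 60,000 training and 10,000 test images, each reshaped to `(28, 28, 1)` and scaled
to the `[0, 1]` range.
"""

# Optional: verify shapes and dtypes after preprocessing.
print(x_train.shape, x_train.dtype)  # (60000, 28, 28, 1) float32
print(x_test.shape, x_test.dtype)  # (10000, 28, 28, 1) float32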


"""
## Train the teacher

In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start
by training the teacher model on the training set in the usual way.
"""

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
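
"""
Because the teacher is trained once and then only used for inference, it can be convenient
to persist it. This is an optional aside, not part of the original example; the filename
is arbitrary.
"""

# Optional: save the trained teacher and reload it later for distillation.
# teacher.save("teacher.keras")
# teacher = keras.models.load_model("teacher.keras")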

"""
## Distill teacher to student

We have already trained the teacher model, and we only need to initialize a
`Distiller(student, teacher)` instance, `compile()` it with the desired losses,
hyperparameters and optimizer, and distill the teacher to the student.
"""

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
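
"""
After distillation, the student is typically deployed on its own. The snippet below is a
small illustrative sketch, not part of the original example: the trained weights live in
the same `student` object we passed to the distiller, so it can be used directly for
inference (remember that it outputs logits, not probabilities).
"""

# Use the distilled student directly for inference on a few test images.
student_logits = distiller.student.predict(x_test[:5])
print("Predicted:", np.argmax(student_logits, axis=1))
print("Actual:   ", y_test[:5])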

"""
## Train student from scratch for comparison

We can also train an equivalent student model from scratch without the teacher, in order
to evaluate the performance gain obtained by knowledge distillation.
"""

# Train student as usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

"""
If the teacher is trained for 5 full epochs and the student is distilled on this teacher
for 3 full epochs, you should in this example experience a performance boost compared to
training the same student model from scratch, and even compared to the teacher itself.
You should expect the teacher to have accuracy around 97.6%, the student trained from
scratch around 97.6%, and the distilled student around 98.1%. Try different random
seeds to compare the effect of different weight initializations.
"""
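
"""
If you want reproducible runs, or want to systematically compare initializations, a global
seed can be fixed before the models are built. This is an optional sketch, not part of the
original example, and it must be executed before the model-construction code above to have
any effect.
"""

# Optional: set a global random seed (Python, NumPy and the backend) before building
# the models, e.g. at the top of the script right after the imports.
# keras.utils.set_random_seed(42)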