"""
2
Title: Knowledge distillation recipes
3
Author: [Sayak Paul](https://twitter.com/RisingSayak)
4
Date created: 2021/08/01
5
Last modified: 2021/08/01
6
Description: Training better student models via knowledge distillation with function matching.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
Knowledge distillation ([Hinton et al.](https://arxiv.org/abs/1503.02531)) is a technique
14
that enables us to compress larger models into smaller ones. This allows us to reap the
15
benefits of high performing larger models, while reducing storage and memory costs and
16
achieving higher inference speed:
17
18
* Smaller models -> smaller memory footprint
19
* Reduced complexity -> fewer floating-point operations (FLOPs)
20
21
In [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237),
22
Beyer et al. investigate various existing setups for performing knowledge distillation
23
and show that all of them lead to sub-optimal performance. Due to this,
24
practitioners often settle for other alternatives (quantization, pruning, weight
25
clustering, etc.) when developing production systems that are resource-constrained.
26
27
Beyer et al. investigate how we can improve the student models that come out
28
of the knowledge distillation process and always match the performance of
29
their teacher models. In this example, we will study the recipes introduced by them, using
30
the [Flowers102 dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). As a
31
reference, with these recipes, the authors were able to produce a ResNet50 model that
32
achieves 82.8% accuracy on the ImageNet-1k dataset.
33
34
In case you need a refresher on knowledge distillation and want to study how it is
35
implemented in Keras, you can refer to
36
[this example](https://keras.io/examples/vision/knowledge_distillation/).
37
You can also follow
38
[this example](https://keras.io/examples/vision/consistency_training/)
39
that shows an extension of knowledge distillation applied to consistency training.
40
"""
41
42
"""
43
## Imports
44
"""
45
46
import os
47
48
os.environ["KERAS_BACKEND"] = "tensorflow"
49
50
import keras
51
import tensorflow as tf
52
53
import matplotlib.pyplot as plt
54
import numpy as np
55
56
import tensorflow_datasets as tfds
57
58
tfds.disable_progress_bar()
59
60
"""
61
## Hyperparameters and constants
62
"""
63
64
AUTO = tf.data.AUTOTUNE # Used to dynamically adjust parallelism.
65
BATCH_SIZE = 64
66
67
# Comes from Table 4 and "Training setup" section.
68
TEMPERATURE = 10 # Used to soften the logits before they go to softmax.
69
INIT_LR = 0.003 # Initial learning rate that will be decayed over the training period.
70
WEIGHT_DECAY = 0.001 # Used for regularization.
71
CLIP_THRESHOLD = 1.0 # Used for clipping the gradients by L2-norm.
72
73
# We will first resize the training images to a bigger size and then we will take
74
# random crops of a lower size.
75
BIGGER = 160
76
RESIZE = 128
77
78
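"""
To build intuition for the `TEMPERATURE` constant, here is a minimal, purely illustrative
sketch (not part of the recipe itself) showing how dividing logits by a temperature before
the softmax spreads probability mass over more classes, which is what exposes the relative
class similarities learned by the teacher:

```python
import numpy as np


def softmax(logits, temperature=1.0):
    scaled = np.asarray(logits, dtype="float64") / temperature
    exps = np.exp(scaled - scaled.max())
    return exps / exps.sum()


logits = [8.0, 2.0, 1.0]  # Hypothetical logits for a 3-class problem.
print(softmax(logits))  # Sharp: almost all mass on the first class.
print(softmax(logits, temperature=10.0))  # Soft: relative similarities become visible.
```
"""
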
"""
79
## Load the Flowers102 dataset
80
"""
81
82
train_ds, validation_ds, test_ds = tfds.load(
83
"oxford_flowers102", split=["train", "validation", "test"], as_supervised=True
84
)
85
print(f"Number of training examples: {train_ds.cardinality()}.")
86
print(f"Number of validation examples: {validation_ds.cardinality()}.")
87
print(f"Number of test examples: {test_ds.cardinality()}.")
88
89
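"""
If you want to double-check the dataset metadata (for example, the number of classes the
student's classification head will need later), `tfds` exposes it through the dataset
builder. This is optional and only shown for reference:

```python
builder = tfds.builder("oxford_flowers102")
print(builder.info.features["label"].num_classes)  # 102
print(builder.info.splits["train"].num_examples)  # 1020
```
"""
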
"""
90
## Teacher model
91
92
As is common with any distillation technique, it's important to first train a
93
well-performing teacher model which is usually larger than the subsequent student model.
94
The authors distill a BiT ResNet152x2 model (teacher) into a BiT ResNet50 model
95
(student).
96
97
BiT stands for Big Transfer and was introduced in
98
[Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370).
99
BiT variants of ResNets use Group Normalization ([Wu et al.](https://arxiv.org/abs/1803.08494))
100
and Weight Standardization ([Qiao et al.](https://arxiv.org/abs/1903.10520v2))
101
in place of Batch Normalization ([Ioffe et al.](https://arxiv.org/abs/1502.03167)).
102
In order to limit the time it takes to run this example, we will be using a BiT
103
ResNet101x3 already trained on the Flowers102 dataset. You can refer to
104
[this notebook](https://github.com/sayakpaul/FunMatch-Distillation/blob/main/train_bit.ipynb)
105
to learn more about the training process. This model reaches 98.18% accuracy on the
106
test set of Flowers102.
107
108
The model weights are hosted on Kaggle as a dataset.
109
To download the weights, follow these steps:
110
111
1. Create an account on Kaggle [here](https://www.kaggle.com).
112
2. Go to the "Account" tab of your [user profile](https://www.kaggle.com/account).
113
3. Select "Create API Token". This will trigger the download of `kaggle.json`, a file
114
containing your API credentials.
115
4. From that JSON file, copy your Kaggle username and API key.
116
117
Now run the following:
118
119
```python
120
import os
121
122
os.environ["KAGGLE_USERNAME"] = "" # TODO: enter your Kaggle user name here
123
os.environ["KAGGLE_KEY"] = "" # TODO: enter your Kaggle key here
124
```
125
126
Once the environment variables are set, run:
127
128
```shell
129
$ kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
130
$ unzip -qq bitresnet101x3flowers102.zip
131
```
132
133
This should generate a folder named `T-r101x3-128` which is essentially a teacher
134
[`SavedModel`](https://www.tensorflow.org/guide/saved_model).
135
"""
136
137
os.environ["KAGGLE_USERNAME"] = "" # TODO: enter your Kaggle user name here
138
os.environ["KAGGLE_KEY"] = "" # TODO: enter your Kaggle API key here
139
140
"""shell
141
!kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
142
"""
143
144
"""shell
145
!unzip -qq bitresnet101x3flowers102.zip
146
"""
147
148
# Since the teacher model is not going to be trained further we make
149
# it non-trainable.
150
teacher_model = keras.layers.TFSMLayer(
151
"/home/jupyter/keras-io/examples/keras_recipes/T-r101x3-128"
152
)
153
teacher_model.trainable = False
154
155
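"""
As an optional sanity check, we can run a dummy batch through the wrapped teacher. This
assumes the exported serving signature returns a single logits tensor of shape
`(batch_size, 102)`, which is also what the `Distiller` below relies on:

```python
dummy_batch = tf.random.uniform((1, RESIZE, RESIZE, 3))
print(teacher_model(dummy_batch).shape)  # Expected: (1, 102)
```
"""
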
"""
156
## The "function matching" recipe
157
158
To train a high-quality student model, the authors propose the following changes to the
159
student training workflow:
160
161
* Use an aggressive variant of MixUp ([Zhang et al.](https://arxiv.org/abs/1710.09412)).
162
This is done by sampling the `alpha` parameter from a uniform distribution instead of a
163
beta distribution. MixUp is used here in order to help the student model capture the
164
function underlying the teacher model. MixUp linearly interpolates between different
165
samples across the data manifold. So the rationale here is if the student is trained to
166
fit that it should be able to match the teacher model better. To incorporate more
167
invariance MixUp is coupled with "Inception-style" cropping
168
([Szegedy et al.](https://arxiv.org/abs/1409.4842)). This is where the
169
"function matching" term makes its way in the
170
[original paper](https://arxiv.org/abs/2106.05237).
171
* Unlike other works ([Noisy Student Training](https://arxiv.org/abs/1911.04252) for
172
example), both the teacher and student models receive the same copy of an image, which is
173
mixed up and randomly cropped. By providing the same inputs to both the models, the
174
authors make the teacher consistent with the student.
175
* With MixUp, we are essentially introducing a strong form of regularization when
176
training the student. As such, it should be trained for a
177
relatively long period of time (1000 epochs at least). Since the student is trained with
178
strong regularization, the risk of overfitting due to a longer training
179
schedule are also mitigated.
180
181
In summary, one needs to be consistent and patient while training the student model.
182
"""
183
184
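"""
Here is the short sketch referenced above. It is purely illustrative (dummy "images",
NumPy only) and contrasts the usual MixUp coefficient, drawn from a Beta distribution,
with the uniform sampling used by the function matching recipe:

```python
import numpy as np

rng = np.random.default_rng(42)
image_a = rng.random((128, 128, 3))  # Dummy stand-ins for two training images.
image_b = rng.random((128, 128, 3))

# Standard MixUp: alpha ~ Beta(a, a); for small `a` (0.2 is a common, illustrative
# choice) the mix is usually mild because most of the mass sits near 0 and 1.
alpha_beta = rng.beta(0.2, 0.2)

# Function matching recipe: alpha ~ Uniform(0, 1), which mixes much more
# aggressively on average.
alpha_uniform = rng.uniform(0.0, 1.0)

mixed = alpha_uniform * image_a + (1.0 - alpha_uniform) * image_b
print(alpha_beta, alpha_uniform, mixed.shape)
```
"""
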
"""
185
## Data input pipeline
186
"""
187
188
189
def mixup(images, labels):
190
alpha = tf.random.uniform([], 0, 1)
191
mixedup_images = alpha * images + (1 - alpha) * tf.reverse(images, axis=[0])
192
# The labels do not matter here since they are NOT used during
193
# training.
194
return mixedup_images, labels
195
196
197
def preprocess_image(image, label, train=True):
198
image = tf.cast(image, tf.float32) / 255.0
199
200
if train:
201
image = tf.image.resize(image, (BIGGER, BIGGER))
202
image = tf.image.random_crop(image, (RESIZE, RESIZE, 3))
203
image = tf.image.random_flip_left_right(image)
204
else:
205
# Central fraction amount is from here:
206
# https://git.io/J8Kda.
207
image = tf.image.central_crop(image, central_fraction=0.875)
208
image = tf.image.resize(image, (RESIZE, RESIZE))
209
210
return image, label
211
212
213
def prepare_dataset(dataset, train=True, batch_size=BATCH_SIZE):
214
if train:
215
dataset = dataset.map(preprocess_image, num_parallel_calls=AUTO)
216
dataset = dataset.shuffle(BATCH_SIZE * 10)
217
else:
218
dataset = dataset.map(
219
lambda x, y: (preprocess_image(x, y, train)), num_parallel_calls=AUTO
220
)
221
dataset = dataset.batch(batch_size)
222
223
if train:
224
dataset = dataset.map(mixup, num_parallel_calls=AUTO)
225
226
dataset = dataset.prefetch(AUTO)
227
return dataset
228
229
230
"""
231
Note that for brevity, we used mild crops for the training set but in practice
232
"Inception-style" preprocessing should be applied. You can refer to
233
[this script](https://github.com/sayakpaul/FunMatch-Distillation/blob/main/crop_resize.py)
234
for a closer implementation. Also, _**the ground-truth labels are not used for
235
training the student.**_
236
"""
237
238
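"""
For reference, here is a hedged sketch of what an "Inception-style" random resized crop
could look like with `tf.image.sample_distorted_bounding_box`. The area and aspect ratio
ranges below are illustrative assumptions; see the linked script for the implementation
actually used by the authors:

```python
def inception_style_crop(image, crop_size=RESIZE):
    # Sample a random box covering 5%-100% of the image area with an aspect
    # ratio in [3/4, 4/3], then resize the crop to the target resolution.
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=tf.zeros((0, 0, 4), dtype=tf.float32),
        area_range=(0.05, 1.0),
        aspect_ratio_range=(3 / 4, 4 / 3),
        use_image_if_no_bounding_boxes=True,
    )
    cropped = tf.slice(image, begin, size)
    return tf.image.resize(cropped, (crop_size, crop_size))
```
"""
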
train_ds = prepare_dataset(train_ds, True)
validation_ds = prepare_dataset(validation_ds, False)
test_ds = prepare_dataset(test_ds, False)

"""
## Visualization
"""

sample_images, _ = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(sample_images[n].numpy())
    plt.axis("off")
plt.show()

"""
255
## Student model
256
257
For the purpose of this example, we will use the standard ResNet50V2
258
([He et al.](https://arxiv.org/abs/1603.05027)).
259
"""
260
261
262
def get_resnetv2():
263
resnet_v2 = keras.applications.ResNet50V2(
264
weights=None,
265
input_shape=(RESIZE, RESIZE, 3),
266
classes=102,
267
classifier_activation="linear",
268
)
269
return resnet_v2
270
271
272
get_resnetv2().count_params()
273
274
"""
275
Compared to the teacher model, this model has 358 Million fewer parameters.
276
"""
277
278
"""
279
## Distillation utility
280
281
We will reuse some code from
282
[this example](https://keras.io/examples/vision/knowledge_distillation/)
283
on knowledge distillation.
284
"""
285
286
287
class Distiller(tf.keras.Model):
288
def __init__(self, student, teacher):
289
super().__init__()
290
self.student = student
291
self.teacher = teacher
292
self.loss_tracker = keras.metrics.Mean(name="distillation_loss")
293
294
@property
295
def metrics(self):
296
metrics = super().metrics
297
metrics.append(self.loss_tracker)
298
return metrics
299
300
def compile(
301
self,
302
optimizer,
303
metrics,
304
distillation_loss_fn,
305
temperature=TEMPERATURE,
306
):
307
super().compile(optimizer=optimizer, metrics=metrics)
308
self.distillation_loss_fn = distillation_loss_fn
309
self.temperature = temperature
310
311
def train_step(self, data):
312
# Unpack data
313
x, _ = data
314
315
# Forward pass of teacher
316
teacher_predictions = self.teacher(x, training=False)
317
318
with tf.GradientTape() as tape:
319
# Forward pass of student
320
student_predictions = self.student(x, training=True)
321
322
# Compute loss
323
distillation_loss = self.distillation_loss_fn(
324
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
325
tf.nn.softmax(student_predictions / self.temperature, axis=1),
326
)
327
328
# Compute gradients
329
trainable_vars = self.student.trainable_variables
330
gradients = tape.gradient(distillation_loss, trainable_vars)
331
332
# Update weights
333
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
334
335
# Report progress
336
self.loss_tracker.update_state(distillation_loss)
337
return {"distillation_loss": self.loss_tracker.result()}
338
339
def test_step(self, data):
340
# Unpack data
341
x, y = data
342
343
# Forward passes
344
teacher_predictions = self.teacher(x, training=False)
345
student_predictions = self.student(x, training=False)
346
347
# Calculate the loss
348
distillation_loss = self.distillation_loss_fn(
349
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
350
tf.nn.softmax(student_predictions / self.temperature, axis=1),
351
)
352
353
# Report progress
354
self.loss_tracker.update_state(distillation_loss)
355
self.compiled_metrics.update_state(y, student_predictions)
356
results = {m.name: m.result() for m in self.metrics}
357
return results
358
359
360
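"""
To make the objective above concrete, here is a tiny, self-contained sketch (with made-up
logits) of the loss the `Distiller` minimizes: the KL divergence between the
temperature-softened teacher and student distributions:

```python
teacher_logits = tf.constant([[5.0, 1.0, -2.0]])
student_logits = tf.constant([[2.0, 2.0, -1.0]])

kl = keras.losses.KLDivergence()
loss = kl(
    tf.nn.softmax(teacher_logits / TEMPERATURE, axis=1),
    tf.nn.softmax(student_logits / TEMPERATURE, axis=1),
)
print(float(loss))
```
"""
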
"""
361
## Learning rate schedule
362
363
A warmup cosine learning rate schedule is used in the paper. This schedule is also
364
typical for many pre-training methods especially for computer vision.
365
"""
366
367
# Some code is taken from:
368
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
369
370
371
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
372
def __init__(
373
self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
374
):
375
super().__init__()
376
377
self.learning_rate_base = learning_rate_base
378
self.total_steps = total_steps
379
self.warmup_learning_rate = warmup_learning_rate
380
self.warmup_steps = warmup_steps
381
self.pi = tf.constant(np.pi)
382
383
def __call__(self, step):
384
if self.total_steps < self.warmup_steps:
385
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
386
387
cos_annealed_lr = tf.cos(
388
self.pi
389
* (tf.cast(step, tf.float32) - self.warmup_steps)
390
/ float(self.total_steps - self.warmup_steps)
391
)
392
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
393
394
if self.warmup_steps > 0:
395
if self.learning_rate_base < self.warmup_learning_rate:
396
raise ValueError(
397
"Learning_rate_base must be larger or equal to "
398
"warmup_learning_rate."
399
)
400
slope = (
401
self.learning_rate_base - self.warmup_learning_rate
402
) / self.warmup_steps
403
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
404
learning_rate = tf.where(
405
step < self.warmup_steps, warmup_rate, learning_rate
406
)
407
return tf.where(
408
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
409
)
410
411
412
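"""
In other words, using the same variable names as the class above, the schedule computes
(ignoring the cut-off to 0 after `total_steps`):

* for `step < warmup_steps`:
`lr(step) = warmup_learning_rate + step * (learning_rate_base - warmup_learning_rate) / warmup_steps`
* afterwards:
`lr(step) = 0.5 * learning_rate_base * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps)))`
"""
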
"""
413
We can now plot a a graph of learning rates generated using this schedule.
414
"""
415
416
ARTIFICIAL_EPOCHS = 1000
ARTIFICIAL_BATCH_SIZE = 512
DATASET_NUM_TRAIN_EXAMPLES = 1020
TOTAL_STEPS = int(
    DATASET_NUM_TRAIN_EXAMPLES / ARTIFICIAL_BATCH_SIZE * ARTIFICIAL_EPOCHS
)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=INIT_LR,
    total_steps=TOTAL_STEPS,
    warmup_learning_rate=0.0,
    warmup_steps=1500,
)

lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]
plt.plot(lrs)
plt.xlabel("Step", fontsize=14)
plt.ylabel("LR", fontsize=14)
plt.show()


"""
437
The original paper uses at least 1000 epochs and a batch size of 512 to perform
438
"function matching". The objective of this example is to present a workflow to
439
implement the recipe and not to demonstrate the results when they are applied at full scale.
440
However, these recipes will transfer to the original settings from the paper. Please
441
refer to [this repository](https://github.com/sayakpaul/FunMatch-Distillation) if you are
442
interested in finding out more.
443
"""
444
445
"""
446
## Training
447
"""
448
449
optimizer = keras.optimizers.AdamW(
450
weight_decay=WEIGHT_DECAY, learning_rate=scheduled_lrs, clipnorm=CLIP_THRESHOLD
451
)
452
453
student_model = get_resnetv2()
454
455
distiller = Distiller(student=student_model, teacher=teacher_model)
456
distiller.compile(
457
optimizer,
458
metrics=[keras.metrics.SparseCategoricalAccuracy()],
459
distillation_loss_fn=keras.losses.KLDivergence(),
460
temperature=TEMPERATURE,
461
)
462
463
history = distiller.fit(
464
train_ds,
465
steps_per_epoch=int(np.ceil(DATASET_NUM_TRAIN_EXAMPLES / BATCH_SIZE)),
466
validation_data=validation_ds,
467
epochs=30, # This should be at least 1000.
468
)
469
470
student = distiller.student
471
student_model.compile(metrics=["accuracy"])
472
_, top1_accuracy = student.evaluate(test_ds)
473
print(f"Top-1 accuracy on the test set: {round(top1_accuracy * 100, 2)}%")
474
475
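"""
If you want to keep the distilled student around for later use, you could persist it at
this point. This is not part of the original recipe, and the filename below is arbitrary:

```python
student.save("student_distilled.keras")
reloaded_student = keras.models.load_model("student_distilled.keras")
```
"""
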
"""
476
## Results
477
478
With just 30 epochs of training, the results are nowhere near expected.
479
This is where the benefits of patience aka a longer training schedule
480
will come into play. Let's investigate what the model trained for 1000 epochs can do.
481
"""
482
483
"""shell
484
# Download the pre-trained weights.
485
!wget https://git.io/JBO3Y -O S-r50x1-128-1000.tar.gz
486
!tar xf S-r50x1-128-1000.tar.gz
487
"""
488
489
pretrained_student = keras.layers.TFSMLayer("S-r50x1-128-1000")
490
491
"""
492
This model exactly follows what the authors have used in their student models.
493
"""
494
495
_, top1_accuracy = pretrained_student.evaluate(test_ds)
496
print(f"Top-1 accuracy on the test set: {round(top1_accuracy * 100, 2)}%")
497
498
"""
499
With 100000 epochs of training, this same model leads to a top-1 accuracy of 95.54%.
500
501
There are a number of important ablations studies presented in the paper that show the
502
effectiveness of these recipes compared to the prior art. So if you are skeptical about
503
these recipes, definitely consult the paper.
504
"""
505
"""
## Note on training for longer

With TPU-based hardware infrastructure, we can train the model for 1000 epochs faster.
This does not even require adding a lot of changes to this codebase. You
are encouraged to check
[this repository](https://github.com/sayakpaul/FunMatch-Distillation)
as it presents TPU-compatible training workflows for these recipes and can be run on
[Kaggle Kernels](https://www.kaggle.com/kernels), leveraging their free TPU v3-8 hardware.
"""

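"""
For completeness, here is a rough, untested sketch of what placing the setup under a TPU
strategy scope typically looks like with `tf.distribute`. It assumes a TPU runtime (such
as a Kaggle TPU) is available and reuses the objects defined earlier in this script; the
linked repository remains the reference for a working TPU workflow:

```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Model and optimizer creation go inside the strategy scope.
    student_model = get_resnetv2()
    distiller = Distiller(student=student_model, teacher=teacher_model)
    distiller.compile(
        optimizer=keras.optimizers.AdamW(
            weight_decay=WEIGHT_DECAY,
            learning_rate=scheduled_lrs,
            clipnorm=CLIP_THRESHOLD,
        ),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        distillation_loss_fn=keras.losses.KLDivergence(),
        temperature=TEMPERATURE,
    )
```
"""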