"""
2
Title: Distilling Vision Transformers
3
Author: [Sayak Paul](https://twitter.com/RisingSayak)
4
Date created: 2022/04/05
5
Last modified: 2026/02/10
6
Description: Distillation of Vision Transformers through attention.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Harshith K](https://github.com/kharshith-k/)
9
"""
10
11
"""
12
## Introduction
13
14
In the original *Vision Transformers* (ViT) paper
15
([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)),
16
the authors concluded that to perform on par with Convolutional Neural Networks (CNNs),
17
ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly
18
due to the lack of inductive biases in the ViT architecture -- unlike CNNs,
19
they don't have layers that exploit locality. In a follow-up paper
20
([Steiner et al.](https://arxiv.org/abs/2106.10270)),
21
the authors show that it is possible to substantially improve the performance of ViTs
22
with stronger regularization and longer training.
23
24
Many groups have proposed different ways to deal with the problem
25
of data-intensiveness of ViT training.
26
One such way was shown in the *Data-efficient image Transformers*,
27
(DeiT) paper ([Touvron et al.](https://arxiv.org/abs/2012.12877)). The
28
authors introduced a distillation technique that is specific to transformer-based vision
29
models. DeiT is among the first works to show that it's possible to train ViTs well
30
without using larger datasets.
31
32
In this example, we implement the distillation recipe proposed in DeiT. This
33
requires us to slightly tweak the original ViT architecture and write a custom training
34
loop to implement the distillation recipe.
35
36
To comfortably navigate through this example, you'll be expected to know how a ViT and
37
knowledge distillation work. The following are good resources in case you needed a
38
refresher:
39
40
* [ViT on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer)
41
* [Knowledge distillation on keras.io](https://keras.io/examples/vision/knowledge_distillation/)
42
"""
43
44
"""
45
## Imports
46
"""
47
48
from typing import List
49
50
import tensorflow as tf
51
import tensorflow_datasets as tfds
52
import keras
53
from keras import layers
54
55
tfds.disable_progress_bar()
56
keras.utils.set_random_seed(42)
57
58
"""
59
## Constants
60
"""
61
62
# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001

# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5

"""
You probably noticed that `DROPOUT_RATE` has been set to 0.0. Dropout is included
in the implementation to keep it complete. For smaller models (like the one used in
this example) you don't need it, but it helps when training bigger models.
"""

"""
95
## Load the `tf_flowers` dataset and prepare preprocessing utilities
96
97
The authors use an array of different augmentation techniques, including MixUp
98
([Zhang et al.](https://arxiv.org/abs/1710.09412)),
99
RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)),
100
and so on. However, to keep the example simple to work through, we'll discard them.
101
"""
102
103
104
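"""
For reference, here is a minimal sketch of what a MixUp step could look like if you wanted to
add it back. It assumes batched `(image, one-hot label)` pairs, so it would be mapped over the
dataset *after* batching (see `prepare_dataset` below); the `alpha` value is illustrative
rather than taken from the paper, and we don't actually use this function in the example.

```
def mixup(images, labels, alpha=0.2):
    # Sample one mixing coefficient per example from a Beta(alpha, alpha)
    # distribution, expressed as a ratio of two Gamma draws.
    batch_size = tf.shape(images)[0]
    gamma_1 = tf.random.gamma([batch_size], alpha)
    gamma_2 = tf.random.gamma([batch_size], alpha)
    lam = gamma_1 / (gamma_1 + gamma_2)

    # Mix every example with the batch read in reverse order.
    images = tf.cast(images, "float32")
    lam_img = tf.reshape(lam, (-1, 1, 1, 1))
    mixed_images = lam_img * images + (1.0 - lam_img) * tf.reverse(images, axis=[0])
    lam_lab = tf.reshape(lam, (-1, 1))
    mixed_labels = lam_lab * labels + (1.0 - lam_lab) * tf.reverse(labels, axis=[0])
    return mixed_images, mixed_labels


# train_dataset = train_dataset.map(mixup, num_parallel_calls=AUTO)
```
"""
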
def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = keras.ops.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            # Perform the random crop with TensorFlow ops for graph compatibility.
            # Get random crop coordinates (0 to 20 pixels offset).
            crop_top = tf.random.uniform((), 0, 21, dtype=tf.int32)
            crop_left = tf.random.uniform((), 0, 21, dtype=tf.int32)
            image = tf.image.crop_to_bounding_box(
                image,
                offset_height=crop_top,
                offset_width=crop_left,
                target_height=RESOLUTION,
                target_width=RESOLUTION,
            )
            # Random horizontal flip.
            if tf.random.uniform(()) > 0.5:
                image = tf.image.flip_left_right(image)
        else:
            image = keras.ops.image.resize(image, (RESOLUTION, RESOLUTION))
        label = keras.ops.one_hot(label, num_classes=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

"""
151
## Implementing the DeiT variants of ViT
152
153
Since DeiT is an extension of ViT it'd make sense to first implement ViT and then extend
154
it to support DeiT's components.
155
156
First, we'll implement a layer for Stochastic Depth
157
([Huang et al.](https://arxiv.org/abs/1603.09382))
158
which is used in DeiT for regularization.
159
"""
160
161
162
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=True):
        # With probability `drop_prob`, zero out the residual branch for a sample
        # and rescale the surviving samples by `1 / keep_prob`.
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(keras.ops.shape(x)) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            random_tensor = keras.ops.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

"""
182
Now, we'll implement the MLP and Transformer blocks.
183
"""
184
185
186
def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation="gelu" if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)

"""
234
We'll now implement a `ViTClassifier` class building on top of the components we just
235
developed. Here we'll be following the original pooling strategy used in the ViT paper --
236
use a class token and use the feature representations corresponding to it for
237
classification.
238
"""
239
240
241
class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Transformer blocks.
        dpr = [x for x in keras.ops.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def build(self, input_shape):
        # Positional embedding.
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )

        # CLS token.
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )
        super().build(input_shape)

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)

        return output

"""
335
This class can be used standalone as ViT and is end-to-end trainable. Just remove the
336
`distilled` phrase in `MODEL_TYPE` and it should work with `vit_tiny = ViTClassifier()`.
337
Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken
338
from the DeiT paper):
339
340
![](https://i.imgur.com/5lmg2Xs.png)
341
342
Apart from the class token, DeiT has another token for distillation. During distillation,
343
the logits corresponding to the class token are compared to the true labels, and the
344
logits corresponding to the distillation token are compared to the teacher's predictions.
345
"""
346
347
348
class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def build(self, input_shape):
        # CLS token.
        self.cls_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="cls",
        )

        # Distillation token.
        self.dist_token = self.add_weight(
            shape=(1, 1, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="dist_token",
        )

        # Positional embedding (for NUM_PATCHES + 2 tokens: cls + dist).
        self.positional_embedding = self.add_weight(
            shape=(1, NUM_PATCHES + self.num_tokens, PROJECTION_DIM),
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name="pos_embedding",
        )

    def call(self, inputs, training=True):
        n = keras.ops.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = keras.ops.tile(self.cls_token, (n, 1, 1))
        dist_token = keras.ops.tile(self.dist_token, (n, 1, 1))
        cls_token = keras.ops.cast(cls_token, projected_patches.dtype)
        dist_token = keras.ops.cast(dist_token, projected_patches.dtype)
        projected_patches = keras.ops.concatenate(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        # Only return separate classification predictions when training in distilled
        # mode.
        if training and not self.regular_training:
            return x, x_dist
        # During standard training / fine-tuning and at inference time, average the
        # predictions of the two heads.
        return (x + x_dist) / 2

"""
435
Let's verify if the `ViTDistilled` class can be initialized and called as expected.
436
"""
437
438
deit_tiny_distilled = ViTDistilled()

dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)

"""
445
## Implementing the trainer
446
447
Unlike what happens in standard knowledge distillation
448
([Hinton et al.](https://arxiv.org/abs/1503.02531)),
449
where a temperature-scaled softmax is used as well as KL divergence,
450
DeiT authors use the following loss function:
451
452
![](https://i.imgur.com/bXdxsBq.png)
453
454
455
Here,
456
457
* CE is cross-entropy
458
* `psi` is the softmax function
459
* Z_s denotes student predictions
460
* y denotes true labels
461
* y_t denotes teacher predictions
462
"""
463
464
465
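"""
To make the figure concrete: the objective we implement below is just the average of two
cross-entropy terms -- one between the class-token logits and the true labels, and one
between the distillation-token logits and the teacher's predictions. A minimal sketch with
made-up tensors (only to illustrate the shapes and the combination; it is not part of the
training code):

```
cce = keras.losses.CategoricalCrossentropy(from_logits=True)
y = keras.ops.one_hot(keras.ops.convert_to_tensor([0, 3]), NUM_CLASSES)  # true labels
z_cls = keras.random.normal((2, NUM_CLASSES))  # class-head logits
z_dist = keras.random.normal((2, NUM_CLASSES))  # distillation-head logits
y_teacher = keras.ops.softmax(keras.random.normal((2, NUM_CLASSES)))  # teacher probabilities

loss = (cce(y, z_cls) + cce(y_teacher, z_dist)) / 2
```
"""
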
class DeiT(keras.Model):
    # Reference:
    # https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
        self.accuracy_metric = keras.metrics.CategoricalAccuracy(name="accuracy")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        metrics.append(self.accuracy_metric)
        return metrics

    def compile(
        self,
        optimizer,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Normalize for the student (the ViT expects inputs in [0, 1]).
        x_student = keras.ops.cast(x, "float32") / 255.0

        # The teacher expects raw [0, 255] float32 inputs (no normalization).
        x_teacher = keras.ops.cast(x, "float32")

        # Forward pass of the teacher.
        # TFSMLayer returns a dictionary, so extract the output.
        teacher_output = self.teacher(x_teacher, training=False)
        if isinstance(teacher_output, dict):
            # Get the first (and likely only) output from the dictionary.
            teacher_output = list(teacher_output.values())[0]
        # Use soft targets (probabilities) for distillation.
        teacher_predictions = keras.ops.nn.softmax(teacher_output, -1)

        with tf.GradientTape() as tape:
            # Forward pass of the student.
            cls_predictions, dist_predictions = self.student(x_student, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.accuracy_metric.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance metrics, including the combined loss.
        return {
            "loss": loss,
            "student_loss": self.student_loss_tracker.result(),
            "distillation_loss": self.dist_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Convert to float32 and normalize for the student.
        x_normalized = keras.ops.cast(x, "float32") / 255.0

        # Compute predictions.
        y_prediction = self.student(x_normalized, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.accuracy_metric.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance metrics.
        return {
            "loss": student_loss,
            "student_loss": self.student_loss_tracker.result(),
            "accuracy": self.accuracy_metric.result(),
        }

    def call(self, inputs):
        # Convert to float32 and normalize for the student.
        inputs_normalized = keras.ops.cast(inputs, "float32") / 255.0
        return self.student(inputs_normalized, training=False)

"""
577
## Load the teacher model
578
579
This model is based on the BiT family of ResNets
580
([Kolesnikov et al.](https://arxiv.org/abs/1912.11370))
581
fine-tuned on the `tf_flowers` dataset. You can refer to
582
[this notebook](https://github.com/sayakpaul/deit-tf/blob/main/notebooks/bit-teacher.ipynb)
583
to know how the training was performed. The teacher model has about 212 Million parameters
584
which is about **40x more** than the student.
585
"""
586
587
"""shell
588
wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
589
unzip -q bit_teacher_flowers.zip
590
"""
591
592
bit_teacher_flowers = keras.layers.TFSMLayer(
593
"bit_teacher_flowers", call_endpoint="serving_default"
594
)
595
596
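"""
`TFSMLayer` wraps the exported SavedModel and typically returns a dictionary keyed by the
model's output names, which is why the `DeiT` trainer above unpacks the teacher's output.
If you want to inspect it yourself, a quick probe could look like this (a sketch; the exact
key name depends on how the SavedModel was exported):

```
probe = bit_teacher_flowers(tf.ones((1, RESOLUTION, RESOLUTION, 3)))
print(probe.keys() if isinstance(probe, dict) else probe.shape)
```
"""
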
"""
597
## Training through distillation
598
"""
599
600
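"""
Following DeiT, the learning rate is scaled linearly with the batch size relative to a
reference batch size of 512: `learning_rate = BASE_LR * BATCH_SIZE / 512`. With
`BASE_LR = 0.0005` and `BATCH_SIZE = 256`, this gives `0.00025`.
"""
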
deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)

lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
    optimizer=keras.optimizers.AdamW(
        weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
    ),
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)

"""
616
If we had trained the same model (the `ViTClassifier`) from scratch with the exact same
617
hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code
618
to reproduce this result:
619
620
```
621
vit_tiny = ViTClassifier()
622
623
inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
624
x = keras.layers.Rescaling(scale=1./255)(inputs)
625
outputs = deit_tiny(x)
626
model = keras.Model(inputs, outputs)
627
628
model.compile(...)
629
model.fit(...)
630
```
631
"""
632
633
"""
634
## Notes
635
636
* Through the use of distillation, we're effectively transferring the inductive biases of
637
a CNN-based teacher model.
638
* Interestingly enough, this distillation strategy works better with a CNN as the teacher
639
model rather than a Transformer as shown in the paper.
640
* The use of regularization to train DeiT models is very important.
641
* ViT models are initialized with a combination of different initializers including
642
truncated normal, random normal, Glorot uniform, etc. If you're looking for
643
end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
644
* If you want to explore the pre-trained DeiT models in Keras with code
645
for fine-tuning, [check out these models on TF-Hub](https://tfhub.dev/sayakpaul/collections/deit/1).
646
647
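For instance (a sketch under our own assumptions, not the exact recipe from the paper or
`timm`), the zero-initialized positional embeddings and tokens defined in `build()` above
could be switched to a truncated-normal initializer:

```
init = keras.initializers.TruncatedNormal(stddev=0.02)
self.positional_embedding = self.add_weight(
    shape=(1, NUM_PATCHES + 1, PROJECTION_DIM),
    initializer=init,
    trainable=True,
    name="pos_embedding",
)
```
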
## Acknowledgements

* Ross Wightman for keeping
[`timm`](https://github.com/rwightman/pytorch-image-models)
updated with readable implementations. I referred to the implementations of ViT and DeiT
a lot while implementing them in Keras.
* [Aritra Roy Gosthipaty](https://github.com/ariG23498),
who implemented some portions of the `ViTClassifier` in another project.
* The [Google Developers Experts](https://developers.google.com/programs/experts/)
program for supporting me with GCP credits, which were used to run experiments for this
example.

Example available on HuggingFace:

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-DEIT-black.svg)](https://huggingface.co/keras-io/deit) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-DEIT-black.svg)](https://huggingface.co/spaces/keras-io/deit/) |
"""

"""
668
## Relevant Chapters from Deep Learning with Python
669
- [Chapter 8: Image classification](https://deeplearningwithpython.io/chapters/chapter08_image-classification)
670
- [Chapter 15: Language models and the Transformer](https://deeplearningwithpython.io/chapters/chapter15_language-models-and-the-transformer)
671
"""
672
673