"""
Title: Distilling Vision Transformers
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.
Accelerator: GPU
"""

"""
## Introduction

In the original *Vision Transformers* (ViT) paper
([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)),
the authors concluded that to perform on par with Convolutional Neural Networks (CNNs),
ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly
due to the lack of inductive biases in the ViT architecture -- unlike CNNs,
they don't have layers that exploit locality. In a follow-up paper
([Steiner et al.](https://arxiv.org/abs/2106.10270)),
the authors show that it is possible to substantially improve the performance of ViTs
with stronger regularization and longer training.

Many groups have proposed different ways to deal with the data-intensiveness
of ViT training.
One such way was shown in the *Data-efficient image Transformers*
(DeiT) paper ([Touvron et al.](https://arxiv.org/abs/2012.12877)). The
authors introduced a distillation technique that is specific to transformer-based vision
models. DeiT is among the first works to show that it's possible to train ViTs well
without using larger datasets.

In this example, we implement the distillation recipe proposed in DeiT. This
requires us to slightly tweak the original ViT architecture and write a custom training
loop to implement the distillation recipe.

To run the example, you'll need TensorFlow Addons, which you can install with the
following command:

```
pip install tensorflow-addons
```

To comfortably navigate through this example, you should know how a ViT and
knowledge distillation work. The following are good resources in case you need a
refresher:

* [ViT on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer)
* [Knowledge distillation on keras.io](https://keras.io/examples/vision/knowledge_distillation/)
"""

"""
## Imports
"""

from typing import List

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers

tfds.disable_progress_bar()
tf.keras.utils.set_random_seed(42)

"""
## Constants
"""

# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001

# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5

"""
You probably noticed that `DROPOUT_RATE` has been set to 0.0. Dropout has been used
in the implementation to keep it complete. For smaller models (like the one used in
this example), you don't need it, but for bigger models, using dropout helps.
"""

"""
## Load the `tf_flowers` dataset and prepare preprocessing utilities

The authors use an array of different augmentation techniques, including MixUp
([Zhang et al.](https://arxiv.org/abs/1710.09412)),
RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)),
and so on. However, to keep the example simple to work through, we'll discard them.
"""


def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
        label = tf.one_hot(label, depth=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

"""
## Implementing the DeiT variants of ViT

Since DeiT is an extension of ViT, it'd make sense to first implement ViT and then extend
it to support DeiT's components.

First, we'll implement a layer for Stochastic Depth
([Huang et al.](https://arxiv.org/abs/1603.09382))
which is used in DeiT for regularization.
"""


# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
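

"""
A quick, illustrative sanity check (not part of the original recipe; the variable names
below are only for this check): with `drop_prob=0.5`, roughly half of the samples in a
batch have their residual branch zeroed out during training, and the surviving samples
are scaled by `1 / keep_prob`.
"""

sd_layer = StochasticDepth(0.5)
sd_outputs = sd_layer(tf.ones((8, 4, 2)), training=True)
print(sd_outputs.shape)  # (8, 4, 2); some samples are all zeros, the rest are scaled by 2.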


"""
Now, we'll implement the MLP and Transformer blocks.
"""


def mlp(x, dropout_rate: float, hidden_units: List):
    """FFN for a Transformer block."""
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """Transformer block with pre-norm."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)
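

"""
Another quick, illustrative check (not in the original example; the block name below is
arbitrary): because `MODEL_TYPE` contains `"distilled"`, each block expects
`NUM_PATCHES + 2` tokens (the patch tokens plus the class and distillation tokens) and
preserves the input shape.
"""

sanity_block = transformer(drop_prob=0.0, name="transformer_block_sanity_check")
print(sanity_block(tf.zeros((1, NUM_PATCHES + 2, PROJECTION_DIM))).shape)  # (1, 198, 192)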


"""
We'll now implement a `ViTClassifier` class building on top of the components we just
developed. Here we'll follow the original pooling strategy used in the ViT paper --
use a class token and take the feature representation corresponding to it for
classification.
"""


class ViTClassifier(keras.Model):
    """Vision Transformer base class."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_shape = (
            1,
            NUM_PATCHES + 1,
            PROJECTION_DIM,
        )
        self.positional_embedding = tf.Variable(
            tf.zeros(init_shape), name="pos_embedding"
        )

        # Transformer blocks.
        dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # CLS token.
        initial_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.cls_token = tf.Variable(
            initial_value=initial_value, trainable=True, name="cls"
        )

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the class token.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)
        return output


"""
This class can be used standalone as ViT and is end-to-end trainable. Just remove
`distilled` from `MODEL_TYPE` and it will work with `vit_tiny = ViTClassifier()`.
Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken
from the DeiT paper):

![](https://i.imgur.com/5lmg2Xs.png)

Apart from the class token, DeiT has another token for distillation. During distillation,
the logits corresponding to the class token are compared to the true labels, and the
logits corresponding to the distillation token are compared to the teacher's predictions.
"""


class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    NUM_PATCHES + self.num_tokens,
                    PROJECTION_DIM,
                )
            ),
            name="pos_embedding",
        )

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of
        # Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        if not training or self.regular_training:
            # During standard training / fine-tuning and at inference time,
            # average the predictions of the two heads.
            return (x + x_dist) / 2

        elif training:
            # Only return separate classification predictions when training in distilled
            # mode.
            return x, x_dist


"""
Let's verify that the `ViTDistilled` class can be initialized and called as expected.
"""

deit_tiny_distilled = ViTDistilled()

dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
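# With `training=False`, the two head outputs are averaged, so we expect a single
# tensor of shape (2, NUM_CLASSES), i.e. (2, 5).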
print(outputs.shape)

"""
## Implementing the trainer

Unlike what happens in standard knowledge distillation
([Hinton et al.](https://arxiv.org/abs/1503.02531)),
where a temperature-scaled softmax and the KL divergence are used,
the DeiT authors use the following loss function:

![](https://i.imgur.com/bXdxsBq.png)


Here,

* CE is cross-entropy
* `psi` is the softmax function
* Z_s denotes student predictions
* y denotes true labels
* y_t denotes teacher predictions
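
Written out (this is the hard-distillation objective from the DeiT paper, transcribed here
for convenience), the loss reads

$$
\mathcal{L} = \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s), y\big)
            + \frac{1}{2}\,\mathrm{CE}\big(\psi(Z_s), y_t\big),
\qquad y_t = \mathrm{argmax}_c\, Z_t(c),
$$

where `Z_t` denotes the teacher's logits, so `y_t` is the teacher's hard prediction. This is
what the custom `train_step` below computes: cross-entropy against the true labels for the
class-token head, cross-entropy against the teacher's argmax predictions for the
distillation head, and the average of the two as the final loss.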
"""


class DeiT(keras.Model):
    # Reference:
    # https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of teacher
        teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
        teacher_predictions = tf.argmax(teacher_predictions, -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions = self.student(x / 255.0, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Compute predictions.
        y_prediction = self.student(x / 255.0, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        return self.student(inputs / 255.0, training=False)


"""
## Load the teacher model

This model is based on the BiT family of ResNets
([Kolesnikov et al.](https://arxiv.org/abs/1912.11370))
fine-tuned on the `tf_flowers` dataset. You can refer to
[this notebook](https://github.com/sayakpaul/deit-tf/blob/main/notebooks/bit-teacher.ipynb)
to learn how the training was performed. The teacher model has about 212 million parameters,
which is about **40x more** than the student.
"""

"""shell
wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
unzip -q bit_teacher_flowers.zip
"""

bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
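
# Optional sanity check (not in the original script): confirm the teacher's size.
print(f"Teacher parameters: {bit_teacher_flowers.count_params() / 1e6:.2f} M")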

"""
## Training through distillation
"""

deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)

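# Scale the base learning rate linearly with the batch size; the base value appears to be
# specified for a reference batch size of 512 (the linear scaling rule).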
lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)

"""
If we had trained the same model (the `ViTClassifier`) from scratch with the exact same
hyperparameters, the model would have scored about 59% accuracy. You can adapt the
following code to reproduce this result:

```
vit_tiny = ViTClassifier()

inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)

model.compile(...)
model.fit(...)
```
"""

"""
## Notes

* Through the use of distillation, we're effectively transferring the inductive biases of
a CNN-based teacher model to the student.
* Interestingly enough, this distillation strategy works better with a CNN as the teacher
model rather than with a Transformer, as shown in the paper.
* The use of regularization to train DeiT models is very important.
* ViT models are initialized with a combination of different initializers, including
truncated normal, random normal, Glorot uniform, etc. If you're looking for
end-to-end reproduction of the original results, don't forget to initialize the ViTs well
(see the sketch after this list).
* If you want to explore the pre-trained DeiT models in TensorFlow and Keras with code
for fine-tuning, [check out these models on TF-Hub](https://tfhub.dev/sayakpaul/collections/deit/1).
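
As a small illustrative sketch of the initialization point above (a truncated normal
initializer is a common choice for ViT-style models; the exact scheme used by the original
implementations may differ):

```
init = keras.initializers.TruncatedNormal(stddev=0.02)
pos_embedding = tf.Variable(
    init((1, NUM_PATCHES + 1, PROJECTION_DIM)), name="pos_embedding"
)
cls_token = tf.Variable(init((1, 1, PROJECTION_DIM)), name="cls")
```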

## Acknowledgements

* Ross Wightman for keeping
[`timm`](https://github.com/rwightman/pytorch-image-models)
updated with readable implementations. I referred to the implementations of ViT and DeiT
a lot while implementing them in TensorFlow.
* [Aritra Roy Gosthipaty](https://github.com/ariG23498)
who implemented some portions of the `ViTClassifier` in another project.
* [Google Developers Experts](https://developers.google.com/programs/experts/)
program for supporting me with GCP credits, which were used to run experiments for this
example.

Example available on HuggingFace:

| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/🤗%20Model-DEIT-black.svg)](https://huggingface.co/keras-io/deit) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-DEIT-black.svg)](https://huggingface.co/spaces/keras-io/deit/) |
"""