"""
2
Title: Class Attention Image Transformers with LayerScale
3
Author: [Sayak Paul](https://twitter.com/RisingSayak)
4
Date created: 2022/09/19
5
Last modified: 2022/11/21
6
Description: Implementing an image transformer equipped with Class Attention and LayerScale.
7
Accelerator: None
8
"""
9
10
"""
11
12
## Introduction
13
14
In this tutorial, we implement the CaiT (Class-Attention in Image Transformers)
15
proposed in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) by
16
Touvron et al. Depth scaling, i.e. increasing the model depth for obtaining better
17
performance and generalization has been quite successful for convolutional neural
18
networks ([Tan et al.](https://arxiv.org/abs/1905.11946),
19
[Dollár et al.](https://arxiv.org/abs/2103.06877), for example). But applying
20
the same model scaling principles to
21
Vision Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) doesn't
22
translate equally well -- their performance gets saturated quickly with depth scaling.
23
Note that one assumption here is that the underlying pre-training dataset is
24
always kept fixed when performing model scaling.
25
26
In the CaiT paper, the authors investigate this phenomenon and propose modifications to
27
the vanilla ViT (Vision Transformers) architecture to mitigate this problem.
28
29
The tutorial is structured like so:
30
31
* Implementation of the individual blocks of CaiT
32
* Collating all the blocks to create the CaiT model
33
* Loading a pre-trained CaiT model
34
* Obtaining prediction results
35
* Visualization of the different attention layers of CaiT
36
37
The readers are assumed to be familiar with Vision Transformers already. Here is
38
an implementation of Vision Transformers in Keras:
39
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
40
"""

"""
## Imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import io
import typing
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import PIL
import keras
from keras import layers
from keras import ops

"""
## The LayerScale layer

We begin by implementing a **LayerScale** layer, which is one of the two modifications
proposed in the CaiT paper.

When the depth of ViT models is increased, they suffer from optimization instability and
eventually don't converge. The residual connections within each Transformer block
introduce an information bottleneck. As the depth increases, this bottleneck can quickly
explode and derail the optimization pathway of the underlying model.

The following equations denote where residual connections are added within a Transformer
block:

<div align="center">
<img src="https://i.ibb.co/jWV5bFb/image.png"/>
</div>

where **SA** stands for self-attention, **FFN** stands for feed-forward network, and
**eta** denotes the LayerNorm operator ([Ba et al.](https://arxiv.org/abs/1607.06450)).

LayerScale is formally implemented like so:

<div align="center">
<img src="https://i.ibb.co/VYDWNn9/image.png"/>
</div>

where the lambdas are learnable parameters initialized with a very small value
({0.1, 1e-5, 1e-6}), and **diag** represents a diagonal matrix.

Intuitively, LayerScale helps control the contribution of the residual branches. The
learnable parameters of LayerScale are initialized to a small value so that the branches
initially act like identity functions and then figure out the degree of interaction
during training. The diagonal matrix additionally helps control the contribution of the
individual dimensions of the residual input, as it is applied on a per-channel basis.

The practical implementation of LayerScale is simpler than it might sound.
"""


class LayerScale(layers.Layer):
    """LayerScale as introduced in CaiT: https://arxiv.org/abs/2103.17239.

    Args:
        init_values (float): value to initialize the diagonal matrix of LayerScale.
        projection_dim (int): projection dimension used in LayerScale.
    """

    def __init__(self, init_values: float, projection_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.gamma = self.add_weight(
            shape=(projection_dim,),
            initializer=keras.initializers.Constant(init_values),
        )

    def call(self, x, training=False):
        return x * self.gamma
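

"""
As a quick sanity check (not part of the original example), we can apply the layer to a
dummy tensor: the output keeps the input shape, and with the tiny initial value every
channel is scaled down accordingly.
"""

# Hypothetical smoke test for `LayerScale`.
layer_scale_demo = LayerScale(init_values=1e-5, projection_dim=192)
demo_outputs = layer_scale_demo(ops.ones((1, 4, 192)))
print(demo_outputs.shape)  # (1, 4, 192)
print(float(ops.max(demo_outputs)))  # ~1e-5, since the inputs are all ones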


"""
## Stochastic depth layer

Since its introduction ([Huang et al.](https://arxiv.org/abs/1603.09382)), Stochastic
Depth has become a favorite component in almost all modern neural network architectures.
CaiT is no exception. Discussing Stochastic Depth is out of scope for this notebook. You
can refer to [this resource](https://paperswithcode.com/method/stochastic-depth) in case
you need a refresher.
"""


class StochasticDepth(layers.Layer):
    """Stochastic Depth layer (https://arxiv.org/abs/1603.09382).

    Reference:
        https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, minval=0, maxval=1, seed=self.seed_generator
            )
            random_tensor = ops.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
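

"""
The behavior is easy to see with a tiny example (a sketch, not part of the original
example): at inference time the layer is the identity, while during training each sample
in the batch is either dropped entirely or rescaled by `1 / keep_prob`.
"""

# Hypothetical demo of `StochasticDepth` behavior.
sd_demo = StochasticDepth(drop_prob=0.5)
demo_batch = ops.ones((4, 2))
print(sd_demo(demo_batch, training=False))  # identical to the inputs
print(sd_demo(demo_batch, training=True))  # each row is either all 0s or all 2s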


"""
## Class attention

The vanilla ViT uses self-attention (SA) layers for modelling how the image patches and
the _learnable_ CLS token interact with each other. The CaiT authors propose to decouple
the attention layers responsible for attending to the image patches and the CLS tokens.

When using ViTs for any discriminative tasks (classification, for example), we usually
take the representations belonging to the CLS token and then pass them to the
task-specific heads. This is as opposed to using something like global average pooling as
is typically done in convolutional neural networks.

The interactions between the CLS token and other image patches are processed uniformly
through self-attention layers. As the CaiT authors point out, this setup has an
entangling effect. On the one hand, the self-attention layers are responsible for
modelling the image patches. On the other hand, they're also responsible for summarizing
the modelled information via the CLS token so that it's useful for the learning objective.

To help disentangle these two things, the authors propose to:

* Introduce the CLS token at a later stage in the network.
* Model the interaction between the CLS token and the representations related to the
image patches through a separate set of attention layers. The authors call this **Class
Attention** (CA).

The figure below (taken from the original paper) depicts this idea:

<div align="center">
<img src="https://i.imgur.com/cxeooHr.png" width=350/>
</div>

This is achieved by treating the CLS token embeddings as the queries in the CA layers.
The CLS token embeddings and the image patch embeddings are fed as keys as well as values.

**Note** that "embeddings" and "representations" have been used interchangeably here.
"""


class ClassAttention(layers.Layer):
    """Class attention as proposed in CaiT: https://arxiv.org/abs/2103.17239.

    Args:
        projection_dim (int): projection dimension for the query, key, and value
            of attention.
        num_heads (int): number of attention heads.
        dropout_rate (float): dropout rate to be used for dropout in the attention
            scores as well as the final projected outputs.
    """

    def __init__(
        self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads

        head_dim = projection_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = layers.Dense(projection_dim)
        self.k = layers.Dense(projection_dim)
        self.v = layers.Dense(projection_dim)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(projection_dim)
        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        batch_size, num_patches, num_channels = (
            ops.shape(x)[0],
            ops.shape(x)[1],
            ops.shape(x)[2],
        )

        # Query projection. `cls_token` embeddings are queries.
        q = ops.expand_dims(self.q(x[:, 0]), axis=1)
        q = ops.reshape(
            q, (batch_size, 1, self.num_heads, num_channels // self.num_heads)
        )  # Shape: (batch_size, 1, num_heads, dimension_per_head)
        q = ops.transpose(q, axes=[0, 2, 1, 3])
        scale = ops.cast(self.scale, dtype=q.dtype)
        q = q * scale

        # Key projection. Patch embeddings as well as the cls embedding are used as keys.
        k = self.k(x)
        k = ops.reshape(
            k, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
        )  # Shape: (batch_size, num_tokens, num_heads, dimension_per_head)
        k = ops.transpose(k, axes=[0, 2, 3, 1])

        # Value projection. Patch embeddings as well as the cls embedding are used as values.
        v = self.v(x)
        v = ops.reshape(
            v, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
        )
        v = ops.transpose(v, axes=[0, 2, 1, 3])

        # Calculate attention scores between cls_token embedding and patch embeddings.
        attn = ops.matmul(q, k)
        attn = ops.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        x_cls = ops.matmul(attn, v)
        x_cls = ops.transpose(x_cls, axes=[0, 2, 1, 3])
        x_cls = ops.reshape(x_cls, (batch_size, 1, num_channels))
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls, training=training)

        return x_cls, attn
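

"""
Before moving on, let's confirm the shapes this layer produces (a quick check, not part
of the original example). With a CLS token prepended to 16 hypothetical patch tokens, the
layer returns a single updated CLS embedding along with attention scores of shape
`(batch_size, num_heads, 1, num_tokens)`.
"""

# Hypothetical shape check for `ClassAttention`.
demo_tokens = ops.ones((1, 1 + 16, 192))  # CLS token followed by 16 patch tokens.
demo_cls, demo_attn = ClassAttention(projection_dim=192, num_heads=4, dropout_rate=0.0)(
    demo_tokens
)
print(demo_cls.shape)  # (1, 1, 192)
print(demo_attn.shape)  # (1, 4, 1, 17)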


"""
## Talking Head Attention

The CaiT authors use Talking Head attention
([Shazeer et al.](https://arxiv.org/abs/2003.02436))
instead of the vanilla scaled dot-product multi-head attention used in
the original Transformer paper
([Vaswani et al.](https://papers.nips.cc/paper/7181-attention-is-all-you-need)).
Talking Head attention introduces two linear projections, one before and one after the
softmax operation, to obtain better results.

For a more rigorous treatment of the Talking Head attention and the vanilla attention
mechanisms, please refer to their respective papers (linked above).
"""


class TalkingHeadAttention(layers.Layer):
    """Talking-head attention (https://arxiv.org/abs/2003.02436), as used in CaiT.

    Args:
        projection_dim (int): projection dimension for the query, key, and value
            of attention.
        num_heads (int): number of attention heads.
        dropout_rate (float): dropout rate to be used for dropout in the attention
            scores as well as the final projected outputs.
    """

    def __init__(
        self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
    ):
        super().__init__(**kwargs)

        self.num_heads = num_heads

        head_dim = projection_dim // self.num_heads

        self.scale = head_dim**-0.5

        self.qkv = layers.Dense(projection_dim * 3)
        self.attn_drop = layers.Dropout(dropout_rate)

        self.proj = layers.Dense(projection_dim)

        self.proj_l = layers.Dense(self.num_heads)
        self.proj_w = layers.Dense(self.num_heads)

        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        B, N, C = ops.shape(x)[0], ops.shape(x)[1], ops.shape(x)[2]

        # Project the inputs all at once.
        qkv = self.qkv(x)

        # Reshape the projected output so that it's segregated into
        # query, key, and value projections.
        qkv = ops.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))

        # Transpose so that the `num_heads` dimension precedes the token dimension.
        # Helps to better segregate the representation sub-spaces.
        qkv = ops.transpose(qkv, axes=[2, 0, 3, 1, 4])
        scale = ops.cast(self.scale, dtype=qkv.dtype)
        q, k, v = qkv[0] * scale, qkv[1], qkv[2]

        # Obtain the raw attention scores.
        attn = ops.matmul(q, ops.transpose(k, axes=[0, 1, 3, 2]))

        # Linear projection of the similarities between the query and key projections.
        attn = self.proj_l(ops.transpose(attn, axes=[0, 2, 3, 1]))

        # Normalize the attention scores.
        attn = ops.transpose(attn, axes=[0, 3, 1, 2])
        attn = ops.nn.softmax(attn, axis=-1)

        # Linear projection on the softmaxed scores.
        attn = self.proj_w(ops.transpose(attn, axes=[0, 2, 3, 1]))
        attn = ops.transpose(attn, axes=[0, 3, 1, 2])
        attn = self.attn_drop(attn, training=training)

        # Final set of projections as done in the vanilla attention mechanism.
        x = ops.matmul(attn, v)
        x = ops.transpose(x, axes=[0, 2, 1, 3])
        x = ops.reshape(x, (B, N, C))

        x = self.proj(x)
        x = self.proj_drop(x, training=training)

        return x, attn


"""
## Feed-forward Network

Next, we implement the feed-forward network, which is one of the components within a
Transformer block.
"""


def mlp(x, dropout_rate: float, hidden_units: typing.List[int]):
    """FFN for a Transformer block."""
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=ops.nn.gelu if idx == 0 else None,
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


"""
## Other blocks

In the next two cells, we implement the remaining blocks as standalone functions:

* `LayerScaleBlockClassAttention()` which returns a `keras.Model`. It is a Transformer block
equipped with Class Attention, LayerScale, and Stochastic Depth. It operates on the CLS
embeddings and the image patch embeddings.
* `LayerScaleBlock()` which returns a `keras.Model`. It is also a Transformer block that
operates only on the embeddings of the image patches. It is equipped with LayerScale and
Stochastic Depth.
"""


def LayerScaleBlockClassAttention(
    projection_dim: int,
    num_heads: int,
    layer_norm_eps: float,
    init_values: float,
    mlp_units: typing.List[int],
    dropout_rate: float,
    sd_prob: float,
    name: str,
):
    """Pre-norm transformer block meant to be applied to the embeddings of the
    cls token and the embeddings of image patches.

    Includes LayerScale and Stochastic Depth.

    Args:
        projection_dim (int): projection dimension to be used in the
            Transformer blocks and patch projection layer.
        num_heads (int): number of attention heads.
        layer_norm_eps (float): epsilon to be used for Layer Normalization.
        init_values (float): initial value for the diagonal matrix used in LayerScale.
        mlp_units (List[int]): dimensions of the feed-forward network used in
            the Transformer blocks.
        dropout_rate (float): dropout rate to be used for dropout in the attention
            scores as well as the final projected outputs.
        sd_prob (float): stochastic depth rate.
        name (str): a name identifier for the block.

    Returns:
        A keras.Model instance.
    """
    x = keras.Input((None, projection_dim))
    x_cls = keras.Input((None, projection_dim))
    inputs = keras.layers.Concatenate(axis=1)([x_cls, x])

    # Class attention (CA).
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(inputs)
    attn_output, attn_scores = ClassAttention(projection_dim, num_heads, dropout_rate)(
        x1
    )
    attn_output = (
        LayerScale(init_values, projection_dim)(attn_output)
        if init_values
        else attn_output
    )
    attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
    x2 = keras.layers.Add()([x_cls, attn_output])

    # FFN.
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
    x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
    outputs = keras.layers.Add()([x2, x4])

    return keras.Model([x, x_cls], [outputs, attn_scores], name=name)


def LayerScaleBlock(
    projection_dim: int,
    num_heads: int,
    layer_norm_eps: float,
    init_values: float,
    mlp_units: typing.List[int],
    dropout_rate: float,
    sd_prob: float,
    name: str,
):
    """Pre-norm transformer block meant to be applied to the embeddings of the
    image patches.

    Includes LayerScale and Stochastic Depth.

    Args:
        projection_dim (int): projection dimension to be used in the
            Transformer blocks and patch projection layer.
        num_heads (int): number of attention heads.
        layer_norm_eps (float): epsilon to be used for Layer Normalization.
        init_values (float): initial value for the diagonal matrix used in LayerScale.
        mlp_units (List[int]): dimensions of the feed-forward network used in
            the Transformer blocks.
        dropout_rate (float): dropout rate to be used for dropout in the attention
            scores as well as the final projected outputs.
        sd_prob (float): stochastic depth rate.
        name (str): a name identifier for the block.

    Returns:
        A keras.Model instance.
    """
    encoded_patches = keras.Input((None, projection_dim))

    # Self-attention.
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    attn_output, attn_scores = TalkingHeadAttention(
        projection_dim, num_heads, dropout_rate
    )(x1)
    attn_output = (
        LayerScale(init_values, projection_dim)(attn_output)
        if init_values
        else attn_output
    )
    attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
    x2 = layers.Add()([encoded_patches, attn_output])

    # FFN.
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
    x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, [outputs, attn_scores], name=name)
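

"""
To get a feel for how these blocks compose, we can instantiate one on dummy patch
embeddings (again, just a sanity check that is not part of the original example).
"""

# Hypothetical smoke test for a single `LayerScaleBlock`.
demo_block = LayerScaleBlock(
    projection_dim=192,
    num_heads=4,
    layer_norm_eps=1e-6,
    init_values=1e-5,
    mlp_units=[192 * 4, 192],
    dropout_rate=0.0,
    sd_prob=0.0,
    name="demo_sa_ffn_block",
)
demo_outputs, demo_attn_scores = demo_block(ops.ones((1, 16, 192)))
print(demo_outputs.shape)  # (1, 16, 192)
print(demo_attn_scores.shape)  # (1, 4, 16, 16)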


"""
Given all these blocks, we are now ready to collate them into the final CaiT model.
"""

"""
## Putting the pieces together: The CaiT model
"""


class CaiT(keras.Model):
    """CaiT model.

    Args:
        projection_dim (int): projection dimension to be used in the
            Transformer blocks and patch projection layer.
        patch_size (int): patch size of the input images.
        num_patches (int): number of patches after extracting the image patches.
        init_values (float): initial value for the diagonal matrix used in LayerScale.
        mlp_units (List[int]): dimensions of the feed-forward network used in
            the Transformer blocks.
        sa_ffn_layers (int): number of self-attention Transformer blocks.
        ca_ffn_layers (int): number of class-attention Transformer blocks.
        num_heads (int): number of attention heads.
        layer_norm_eps (float): epsilon to be used for Layer Normalization.
        dropout_rate (float): dropout rate to be used for dropout in the attention
            scores as well as the final projected outputs.
        sd_prob (float): stochastic depth rate.
        global_pool (str): denotes how to pool the representations coming out of
            the final Transformer block.
        pre_logits (bool): if set to True then don't add a classification head.
        num_classes (int): number of classes to construct the final classification
            layer with.
    """

    def __init__(
        self,
        projection_dim: int,
        patch_size: int,
        num_patches: int,
        init_values: float,
        mlp_units: typing.List[int],
        sa_ffn_layers: int,
        ca_ffn_layers: int,
        num_heads: int,
        layer_norm_eps: float,
        dropout_rate: float,
        sd_prob: float,
        global_pool: str,
        pre_logits: bool,
        num_classes: int,
        **kwargs,
    ):
        if global_pool not in ["token", "avg"]:
            raise ValueError(
                'Invalid value received for `global_pool`, should be either `"token"` or `"avg"`.'
            )

        super().__init__(**kwargs)

        # Responsible for patchifying the input images and then linearly projecting them.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=projection_dim,
                    kernel_size=(patch_size, patch_size),
                    strides=(patch_size, patch_size),
                    padding="VALID",
                    name="conv_projection",
                    kernel_initializer="lecun_normal",
                ),
                layers.Reshape(
                    target_shape=(-1, projection_dim),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # CLS token and the positional embeddings.
        self.cls_token = self.add_weight(
            shape=(1, 1, projection_dim), initializer="zeros"
        )
        self.pos_embed = self.add_weight(
            shape=(1, num_patches, projection_dim), initializer="zeros"
        )

        # Projection dropout.
        self.pos_drop = layers.Dropout(dropout_rate, name="projection_dropout")

        # Stochastic depth schedule.
        dpr = [sd_prob for _ in range(sa_ffn_layers)]

        # Self-attention (SA) Transformer blocks operating only on the image patch
        # embeddings.
        self.blocks = [
            LayerScaleBlock(
                projection_dim=projection_dim,
                num_heads=num_heads,
                layer_norm_eps=layer_norm_eps,
                init_values=init_values,
                mlp_units=mlp_units,
                dropout_rate=dropout_rate,
                sd_prob=dpr[i],
                name=f"sa_ffn_block_{i}",
            )
            for i in range(sa_ffn_layers)
        ]

        # Class Attention (CA) Transformer blocks operating on the CLS token and image
        # patch embeddings.
        self.blocks_token_only = [
            LayerScaleBlockClassAttention(
                projection_dim=projection_dim,
                num_heads=num_heads,
                layer_norm_eps=layer_norm_eps,
                init_values=init_values,
                mlp_units=mlp_units,
                dropout_rate=dropout_rate,
                name=f"ca_ffn_block_{i}",
                sd_prob=0.0,  # No Stochastic Depth in the class attention layers.
            )
            for i in range(ca_ffn_layers)
        ]

        # Pre-classification layer normalization.
        self.norm = layers.LayerNormalization(epsilon=layer_norm_eps, name="head_norm")

        # Representation pooling for classification head.
        self.global_pool = global_pool

        # Classification head.
        self.pre_logits = pre_logits
        self.num_classes = num_classes
        if not pre_logits:
            self.head = layers.Dense(num_classes, name="classification_head")

    def call(self, x, training=False):
        # Notice how the CLS token is not added here.
        x = self.projection(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # SA+FFN layers.
        sa_ffn_attn = {}
        for blk in self.blocks:
            x, attn_scores = blk(x)
            sa_ffn_attn[f"{blk.name}_att"] = attn_scores

        # CA+FFN layers.
        ca_ffn_attn = {}
        cls_tokens = ops.tile(self.cls_token, (ops.shape(x)[0], 1, 1))
        for blk in self.blocks_token_only:
            cls_tokens, attn_scores = blk([x, cls_tokens])
            ca_ffn_attn[f"{blk.name}_att"] = attn_scores

        x = ops.concatenate([cls_tokens, x], axis=1)
        x = self.norm(x)

        # Always return the attention scores from the SA+FFN and CA+FFN layers
        # for convenience.
        if self.global_pool:
            x = (
                ops.mean(x[:, 1:], axis=1)
                if self.global_pool == "avg"
                else x[:, 0]
            )
        return (
            (x, sa_ffn_attn, ca_ffn_attn)
            if self.pre_logits
            else (self.head(x), sa_ffn_attn, ca_ffn_attn)
        )


"""
Having the SA and CA layers segregated this way helps the model focus on its underlying
objectives more concretely:

* modelling dependencies between the image patches
* summarizing the information from the image patches in a CLS token that can be used for
the task at hand

Now that we have defined the CaiT model, it's time to test it. We will start by defining
a model configuration that will be passed to our `CaiT` class for initialization.
"""

"""
## Defining Model Configuration
"""


def get_config(
    image_size: int = 224,
    patch_size: int = 16,
    projection_dim: int = 192,
    sa_ffn_layers: int = 24,
    ca_ffn_layers: int = 2,
    num_heads: int = 4,
    mlp_ratio: int = 4,
    layer_norm_eps=1e-6,
    init_values: float = 1e-5,
    dropout_rate: float = 0.0,
    sd_prob: float = 0.0,
    global_pool: str = "token",
    pre_logits: bool = False,
    num_classes: int = 1000,
) -> typing.Dict:
    """Default configuration for CaiT models (cait_xxs24_224).

    Reference:
        https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cait.py
    """
    config = {}

    # Patchification and projection.
    config["patch_size"] = patch_size
    config["num_patches"] = (image_size // patch_size) ** 2

    # LayerScale.
    config["init_values"] = init_values

    # Dropout and Stochastic Depth.
    config["dropout_rate"] = dropout_rate
    config["sd_prob"] = sd_prob

    # Shared across different blocks and layers.
    config["layer_norm_eps"] = layer_norm_eps
    config["projection_dim"] = projection_dim
    config["mlp_units"] = [
        projection_dim * mlp_ratio,
        projection_dim,
    ]

    # Attention layers.
    config["num_heads"] = num_heads
    config["sa_ffn_layers"] = sa_ffn_layers
    config["ca_ffn_layers"] = ca_ffn_layers

    # Representation pooling and task specific parameters.
    config["global_pool"] = global_pool
    config["pre_logits"] = pre_logits
    config["num_classes"] = num_classes

    return config


"""
Most of the configuration variables should sound familiar to you if you already know the
ViT architecture. Pay particular attention to `sa_ffn_layers` and `ca_ffn_layers`, which
control the number of SA-Transformer blocks and CA-Transformer blocks. You can easily
amend this `get_config()` function to instantiate a CaiT model for your own dataset, as
shown below.
"""

"""
## Model Instantiation
"""

image_size = 224
num_channels = 3
batch_size = 2

config = get_config()
cait_xxs24_224 = CaiT(**config)

dummy_inputs = ops.ones((batch_size, image_size, image_size, num_channels))
_ = cait_xxs24_224(dummy_inputs)
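
"""
The forward pass returns the logits along with two dictionaries holding the attention
scores of the SA and CA blocks. The quick inspection below is not part of the original
example, but it helps confirm the wiring and shows the model's size.
"""

# Hypothetical inspection of the model outputs and parameter count.
demo_logits, demo_sa_attn, demo_ca_attn = cait_xxs24_224(dummy_inputs)
print("Logits shape:", demo_logits.shape)  # (batch_size, num_classes)
print("Number of SA attention maps:", len(demo_sa_attn))
print("Number of CA attention maps:", len(demo_ca_attn))
print(f"Parameter count (millions): {cait_xxs24_224.count_params() / 1e6:.2f}")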

"""
We can successfully perform inference with the model. But what about implementation
correctness? There are many ways to verify it:

* Obtain the performance of the model (given it's been populated with the pre-trained
parameters) on the ImageNet-1k validation set (as the pretraining dataset was
ImageNet-1k).
* Fine-tune the model on a different dataset.

In order to verify this, we will load another instance of the same model that has already
been populated with the pre-trained parameters. Please refer to
[this repository](https://github.com/sayakpaul/cait-tf)
(developed by the author of this notebook) for more details.
Additionally, the repository provides code to verify model performance on the
[ImageNet-1k validation set](https://github.com/sayakpaul/cait-tf/tree/main/i1k_eval)
as well as
[fine-tuning](https://github.com/sayakpaul/cait-tf/blob/main/notebooks/finetune.ipynb).
"""

"""
## Load a pretrained model
"""

model_gcs_path = "gs://kaggle-tfhub-models-uncompressed/tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
pretrained_model = keras.Sequential(
    [keras.layers.TFSMLayer(model_gcs_path, call_endpoint="serving_default")]
)

"""
## Inference utilities

In the next couple of cells, we develop the preprocessing utilities needed to run
inference with the pretrained model.
"""

# The preprocessing transformations include center cropping and normalizing
# the pixel values with the ImageNet-1k training stats (mean and standard deviation).
crop_layer = keras.layers.CenterCrop(image_size, image_size)
norm_layer = keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)


def preprocess_image(image, size=image_size):
    image = np.array(image)
    image_resized = ops.expand_dims(image, 0)
    resize_size = int((256 / image_size) * size)
    image_resized = ops.image.resize(
        image_resized, (resize_size, resize_size), interpolation="bicubic"
    )
    image_resized = crop_layer(image_resized)
    return norm_layer(image_resized).numpy()


def load_image_from_url(url):
    image_bytes = io.BytesIO(urlopen(url).read())
    image = PIL.Image.open(image_bytes)
    preprocessed_image = preprocess_image(image)
    return image, preprocessed_image


"""
Now, we retrieve the ImageNet-1k class labels, since the model we're loading was
pretrained on the ImageNet-1k dataset.
"""

# ImageNet-1k class labels.
imagenet_labels = (
    "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)
label_path = keras.utils.get_file(origin=imagenet_labels)

with open(label_path, "r") as f:
    lines = f.readlines()
    imagenet_labels = [line.rstrip() for line in lines]

"""
## Load an Image
"""

img_url = "https://i.imgur.com/ErgfLTn.jpg"
image, preprocessed_image = load_image_from_url(img_url)

# https://unsplash.com/photos/Ho93gVTRWW8
plt.imshow(image)
plt.axis("off")
plt.show()

"""
## Obtain Predictions
"""

outputs = pretrained_model.predict(preprocessed_image)
logits = outputs["output_1"]
ca_ffn_block_0_att = outputs["output_3_ca_ffn_block_0_att"]
ca_ffn_block_1_att = outputs["output_3_ca_ffn_block_1_att"]

predicted_label = imagenet_labels[int(np.argmax(logits))]
print(predicted_label)
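
"""
We can also look at the model's top five predictions rather than just the top one (a
small extension, not part of the original example).
"""

# Hypothetical top-5 readout from the logits.
top_5_indices = np.argsort(logits[0])[::-1][:5]
for idx in top_5_indices:
    print(f"{imagenet_labels[idx]}: {logits[0][idx]:.3f}")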

"""
Now that we have obtained the predictions (which appear to be as expected), we can
further extend our investigation. Following the CaiT authors, we can investigate the
attention scores from the attention layers. This helps us get deeper insights into the
modifications introduced in the CaiT paper.
"""

"""
## Visualizing the Attention Layers

We start by inspecting the shape of the attention weights returned by a Class Attention
layer.
"""

# (batch_size, nb_attention_heads, num_cls_token, seq_length)
print("Shape of the attention scores from a class attention block:")
print(ca_ffn_block_0_att.shape)

"""
The shape shows that we have attention weights for each of the individual attention
heads. They quantify the information about how the CLS token is related to itself and the
rest of the image patches.

Next, we write a utility to:

* Visualize what the individual attention heads in the Class Attention layers are
focusing on. This helps us to get an idea of how the _spatial-class relationship_ is
induced in the CaiT model.
* Obtain a saliency map from the first Class Attention layer that helps to understand how
the CA layer aggregates information from the region(s) of interest in the images.

This utility is adapted from Figures 6 and 7 of the original
[CaiT paper](https://arxiv.org/abs/2103.17239). It is also a part of
[this notebook](https://github.com/sayakpaul/cait-tf/blob/main/notebooks/classification.ipynb)
(developed by the author of this tutorial).
"""

# Reference:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py

patch_size = 16


def get_cls_attention_map(
    attention_scores,
    return_saliency=False,
) -> np.ndarray:
    """
    Returns attention scores from a particular attention block.

    Args:
        attention_scores: the attention scores from the attention block to
            visualize.
        return_saliency: a boolean flag; if set to True, also returns the salient
            representations of the attention block.
    """
    w_featmap = preprocessed_image.shape[2] // patch_size
    h_featmap = preprocessed_image.shape[1] // patch_size

    nh = attention_scores.shape[1]  # Number of attention heads.

    # Taking the representations from the CLS token.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(nh, w_featmap, h_featmap)

    if not return_saliency:
        attentions = attentions.transpose((1, 2, 0))

    else:
        attentions = np.mean(attentions, axis=0)
        attentions = (attentions - attentions.min()) / (
            attentions.max() - attentions.min()
        )
        attentions = np.expand_dims(attentions, -1)

    # Resize the attention patches to 224x224 (224: 14x16).
    attentions = ops.image.resize(
        attentions,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        interpolation="bicubic",
    )

    return attentions


"""
In the first CA layer, we notice that the model is focusing solely on the region of
interest.
"""

attentions_ca_block_0 = get_cls_attention_map(ca_ffn_block_0_att)


fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(attentions_ca_block_0.shape[-1]):
    if img_count < attentions_ca_block_0.shape[-1]:
        axes[i].imshow(attentions_ca_block_0[:, :, img_count])
        axes[i].title.set_text(f"Attention head: {img_count}")
        axes[i].axis("off")
        img_count += 1

fig.tight_layout()
plt.show()

"""
In the second CA layer, on the other hand, the model tries to focus more on the context
that contains discriminative signals.
"""

attentions_ca_block_1 = get_cls_attention_map(ca_ffn_block_1_att)


fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(attentions_ca_block_1.shape[-1]):
    if img_count < attentions_ca_block_1.shape[-1]:
        axes[i].imshow(attentions_ca_block_1[:, :, img_count])
        axes[i].title.set_text(f"Attention head: {img_count}")
        axes[i].axis("off")
        img_count += 1

fig.tight_layout()
plt.show()

"""
Finally, we obtain the saliency map for the given image.
"""

saliency_attention = get_cls_attention_map(ca_ffn_block_0_att, return_saliency=True)

image = np.array(image)
image_resized = ops.expand_dims(image, 0)
resize_size = int((256 / 224) * image_size)
image_resized = ops.image.resize(
    image_resized, (resize_size, resize_size), interpolation="bicubic"
)
image_resized = crop_layer(image_resized)

plt.imshow(image_resized.numpy().squeeze().astype("int32"))
plt.imshow(saliency_attention.numpy().squeeze(), cmap="cividis", alpha=0.9)
plt.axis("off")

plt.show()

"""
## Conclusion

In this notebook, we implemented the CaiT model. It shows how to mitigate the issues in
ViTs when trying to scale their depth while keeping the pretraining dataset fixed. I hope
the additional visualizations provided in the notebook spark excitement in the community
and that people develop interesting methods to probe what models like ViT learn.

## Acknowledgement

Thanks to the ML Developer Programs team at Google for providing Google Cloud Platform
support.
"""