"""
2
Title: Compact Convolutional Transformers
3
Author: [Sayak Paul](https://twitter.com/RisingSayak)
4
Date created: 2021/06/30
5
Last modified: 2023/08/07
6
Description: Compact Convolutional Transformers for efficient image classification.
7
Accelerator: GPU
8
Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com), [Guillaume Baquiast](https://www.linkedin.com/in/guillaume-baquiast-478965ba/)
9
"""
10
11
"""
12
As discussed in the [Vision Transformers (ViT)](https://arxiv.org/abs/2010.11929) paper,
13
a Transformer-based architecture for vision typically requires a larger dataset than
14
usual, as well as a longer pre-training schedule. [ImageNet-1k](http://imagenet.org/)
15
(which has about a million images) is considered to fall under the medium-sized data regime with
16
respect to ViTs. This is primarily because, unlike CNNs, ViTs (or a typical
17
Transformer-based architecture) do not have well-informed inductive biases (such as
18
convolutions for processing images). This begs the question: can't we combine the
19
benefits of convolution and the benefits of Transformers
20
in a single network architecture? These benefits include parameter-efficiency, and
21
self-attention to process long-range and global dependencies (interactions between
22
different regions in an image).
23
24
In [Escaping the Big Data Paradigm with Compact Transformers](https://arxiv.org/abs/2104.05704),
25
Hassani et al. present an approach for doing exactly this. They proposed the
26
**Compact Convolutional Transformer** (CCT) architecture. In this example, we will work on an
27
implementation of CCT and we will see how well it performs on the CIFAR-10 dataset.
28
29
If you are unfamiliar with the concept of self-attention or Transformers, you can read
30
[this chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11/r-3/312)
31
from François Chollet's book *Deep Learning with Python*. This example uses
32
code snippets from another example,
33
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
34
"""
35
"""
## Imports
"""

from keras import layers
import keras

import matplotlib.pyplot as plt
import numpy as np

"""
## Hyperparameters and constants
"""

positional_emb = True
conv_layers = 2
projection_dim = 128

num_heads = 2
transformer_units = [
    projection_dim,
    projection_dim,
]
transformer_layers = 2
stochastic_depth_rate = 0.1

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 30
image_size = 32

"""
69
## Load CIFAR-10 dataset
70
"""
71
72
num_classes = 10
73
input_shape = (32, 32, 3)
74
75
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
76
77
y_train = keras.utils.to_categorical(y_train, num_classes)
78
y_test = keras.utils.to_categorical(y_test, num_classes)
79
80
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
81
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
82
83
"""
84
## The CCT tokenizer
85
86
The first recipe introduced by the CCT authors is the tokenizer for processing the
87
images. In a standard ViT, images are organized into uniform *non-overlapping* patches.
88
This eliminates the boundary-level information present in between different patches. This
89
is important for a neural network to effectively exploit the locality information. The
90
figure below presents an illustration of how images are organized into patches.
91
92
![](https://i.imgur.com/IkBK9oY.png)
93
94
We already know that convolutions are quite good at exploiting locality information. So,
95
based on this, the authors introduce an all-convolution mini-network to produce image
96
patches.
97
"""
98
99
100
class CCTTokenizer(layers.Layer):
    def __init__(
        self,
        kernel_size=3,
        stride=1,
        padding=1,
        pooling_kernel_size=3,
        pooling_stride=2,
        num_conv_layers=conv_layers,
        num_output_channels=[64, 128],
        positional_emb=positional_emb,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # This is our tokenizer.
        self.conv_model = keras.Sequential()
        for i in range(num_conv_layers):
            self.conv_model.add(
                layers.Conv2D(
                    num_output_channels[i],
                    kernel_size,
                    stride,
                    padding="valid",
                    use_bias=False,
                    activation="relu",
                    kernel_initializer="he_normal",
                )
            )
            self.conv_model.add(layers.ZeroPadding2D(padding))
            self.conv_model.add(
                layers.MaxPooling2D(pooling_kernel_size, pooling_stride, "same")
            )

        self.positional_emb = positional_emb

    def call(self, images):
        outputs = self.conv_model(images)
        # After passing the images through our mini-network the spatial dimensions
        # are flattened to form sequences.
        reshaped = keras.ops.reshape(
            outputs,
            (
                -1,
                keras.ops.shape(outputs)[1] * keras.ops.shape(outputs)[2],
                keras.ops.shape(outputs)[-1],
            ),
        )
        return reshaped

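
"""
As a quick sanity check, we can pass a dummy batch through the tokenizer to see the shape
of the token sequence it produces. With the defaults above, each 32x32x3 image should
become a sequence of 64 tokens, each of dimension 128.
"""

demo_tokens = CCTTokenizer()(np.zeros((1, *input_shape), dtype="float32"))
print(f"Tokenizer output shape: {demo_tokens.shape}")  # (1, 64, 128)
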
"""
152
Positional embeddings are optional in CCT. If we want to use them, we can use
153
the Layer defined below.
154
"""
155
156
157
class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # Trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape

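
"""
A minimal usage sketch: the layer learns one embedding vector per position and broadcasts
it across the batch, so its output has the same shape as the token sequence and can simply
be added to it.
"""

demo_positions = PositionEmbedding(sequence_length=64)(keras.ops.zeros((1, 64, 128)))
print(f"Position embedding shape: {demo_positions.shape}")  # (1, 64, 128)
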
"""
210
## Sequence Pooling
211
Another recipe introduced in CCT is attention pooling or sequence pooling. In ViT, only
212
the feature map corresponding to the class token is pooled and is then used for the
213
subsequent classification task (or any other downstream task).
214
"""
215
216
217
class SequencePooling(layers.Layer):
    def __init__(self):
        super().__init__()
        self.attention = layers.Dense(1)

    def call(self, x):
        # Score every token with a single learned unit, then normalize the scores
        # across the sequence so they sum to one.
        attention_weights = keras.ops.softmax(self.attention(x), axis=1)
        attention_weights = keras.ops.transpose(attention_weights, axes=(0, 2, 1))
        # Weighted average of the token embeddings: (B, 1, N) x (B, N, D) -> (B, 1, D).
        weighted_representation = keras.ops.matmul(attention_weights, x)
        return keras.ops.squeeze(weighted_representation, -2)

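
"""
A minimal shape check: sequence pooling collapses a `(batch, tokens, dim)` sequence into a
single `(batch, dim)` representation, using the learned per-token weights described above.
"""

demo_pooled = SequencePooling()(keras.ops.zeros((2, 64, 128)))
print(f"Pooled representation shape: {demo_pooled.shape}")  # (2, 128)
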
"""
230
## Stochastic depth for regularization
231
232
[Stochastic depth](https://arxiv.org/abs/1603.09382) is a regularization technique that
233
randomly drops a set of layers. During inference, the layers are kept as they are. It is
234
very much similar to [Dropout](https://jmlr.org/papers/v15/srivastava14a.html) but only
235
that it operates on a block of layers rather than individual nodes present inside a
236
layer. In CCT, stochastic depth is used just before the residual blocks of a Transformers
237
encoder.
238
"""
239
240
241
# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_prob
            # Sample one Bernoulli variable per example in the batch: `floor` of
            # `keep_prob + uniform(0, 1)` is 1 with probability `keep_prob`, else 0.
            shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + keras.random.uniform(
                shape, 0, 1, seed=self.seed_generator
            )
            random_tensor = keras.ops.floor(random_tensor)
            # Scale the kept paths by `1 / keep_prob` so the expected value is unchanged.
            return (x / keep_prob) * random_tensor
        return x

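
"""
A small illustration of the behavior: during training, each example's residual branch is
either dropped entirely (with probability `drop_prob`) or scaled by `1 / (1 - drop_prob)`,
which keeps its expected value unchanged. At inference time, the input passes through
untouched.
"""

demo_sd = StochasticDepth(0.5)
demo_x = keras.ops.ones((4, 2))
print(demo_sd(demo_x, training=True))  # rows are either all zeros or all 2.0
print(demo_sd(demo_x, training=False))  # identical to `demo_x`
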
"""
261
## MLP for the Transformers encoder
262
"""
263
264
265
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.ops.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

"""
273
## Data augmentation
274
275
In the [original paper](https://arxiv.org/abs/2104.05704), the authors use
276
[AutoAugment](https://arxiv.org/abs/1805.09501) to induce stronger regularization. For
277
this example, we will be using the standard geometric augmentations like random cropping
278
and flipping.
279
"""
280
281
# Note the rescaling layer. These layers have pre-defined inference behavior.
data_augmentation = keras.Sequential(
    [
        layers.Rescaling(scale=1.0 / 255),
        layers.RandomCrop(image_size, image_size),
        layers.RandomFlip("horizontal"),
    ],
    name="data_augmentation",
)

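
"""
Because the `Rescaling` layer is part of this pipeline, the raw `uint8` images can be fed
to the model directly; the pipeline maps pixel values into the [0, 1] range.
"""

demo_augmented = data_augmentation(x_train[:4])
print(
    f"Augmented batch shape: {demo_augmented.shape}, "
    f"min: {float(keras.ops.min(demo_augmented))}, "
    f"max: {float(keras.ops.max(demo_augmented))}"
)
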
"""
292
## The final CCT model
293
294
In CCT, outputs from the Transformers encoder are weighted and then passed on to the final task-specific layer (in
295
this example, we do classification).
296
"""
297
298
299
def create_cct_model(
    image_size=image_size,
    input_shape=input_shape,
    num_heads=num_heads,
    projection_dim=projection_dim,
    transformer_units=transformer_units,
):
    inputs = layers.Input(input_shape)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Encode patches.
    cct_tokenizer = CCTTokenizer()
    encoded_patches = cct_tokenizer(augmented)

    # Apply positional embedding.
    if positional_emb:
        sequence_length = encoded_patches.shape[1]
        encoded_patches += PositionEmbedding(sequence_length=sequence_length)(
            encoded_patches
        )

    # Calculate Stochastic Depth probabilities.
    dpr = [x for x in np.linspace(0, stochastic_depth_rate, transformer_layers)]

    # Create multiple layers of the Transformer block.
    for i in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        attention_output = StochasticDepth(dpr[i])(attention_output)
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-5)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

        # Skip connection 2.
        x3 = StochasticDepth(dpr[i])(x3)
        encoded_patches = layers.Add()([x3, x2])

    # Apply sequence pooling.
    representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
    weighted_representation = SequencePooling()(representation)

    # Classify outputs.
    logits = layers.Dense(num_classes)(weighted_representation)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

"""
361
## Model training and evaluation
362
"""
363
364
365
def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(
            from_logits=True, label_smoothing=0.1
        ),
        metrics=[
            keras.metrics.CategoricalAccuracy(name="accuracy"),
            keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history

cct_model = create_cct_model()
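# Optional: inspect the model before training. With the hyperparameters defined above,
# the network has roughly 0.4 million parameters.
cct_model.summary()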
history = run_experiment(cct_model)

"""
408
Let's now visualize the training progress of the model.
409
"""
410
411
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

"""
421
The CCT model we just trained has just **0.4 million** parameters, and it gets us to
422
~79% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting as
423
well. This means we can train this network for longer (perhaps with a bit more
424
regularization) and may obtain even better performance. This performance can further be
425
improved by additional recipes like cosine decay learning rate schedule, other data augmentation
426
techniques like [AutoAugment](https://arxiv.org/abs/1805.09501),
427
[MixUp](https://arxiv.org/abs/1710.09412) or
428
[Cutmix](https://arxiv.org/abs/1905.04899). With these modifications, the authors present
429
95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of
430
experiments to study how the number of convolution blocks, Transformers layers, etc.
431
affect the final performance of CCTs.
432
433
For a comparison, a ViT model takes about **4.7 million** parameters and **100
434
epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can
435
refer to
436
[this notebook](https://colab.research.google.com/gist/sayakpaul/1a80d9f582b044354a1a26c5cb3d69e5/image_classification_with_vision_transformer.ipynb)
437
to know about the experimental setup.
438
439
The authors also demonstrate the performance of Compact Convolutional Transformers on
440
NLP tasks and they report competitive results there.
441
"""
442
443