"""
Title: Image classification with ConvMixer
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/10/12
Last modified: 2021/10/12
Description: An all-convolutional network applied to patches of images.
Accelerator: GPU
Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)
"""

"""
## Introduction

Vision Transformers (ViT; [Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) extract
small patches from the input images, linearly project them, and then apply the
Transformer ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) blocks. The application
of ViTs to image recognition tasks is quickly becoming a promising area of research,
because ViTs eliminate the need for strong inductive biases (such as convolutions) for
modeling locality. This presents them as a general computation primitive capable of
learning just from the training data with minimal inductive priors. ViTs yield great
downstream performance when trained with proper regularization, data augmentation, and
relatively large datasets.

In the [Patches Are All You Need](https://openreview.net/pdf?id=TVHS5Y4dNvM) paper (note:
at the time of writing, it is a submission to the ICLR 2022 conference), the authors extend
the idea of using patches to train an all-convolutional network and demonstrate
competitive results. Their architecture, namely **ConvMixer**, uses recipes from recent
isotropic architectures like ViT and MLP-Mixer
([Tolstikhin et al.](https://arxiv.org/abs/2105.01601)), such as using the same
depth and resolution across different layers in the network, residual connections,
and so on.

In this example, we will implement the ConvMixer model and demonstrate its performance on
the CIFAR-10 dataset.
"""

"""
## Imports
"""

import keras
from keras import layers

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

"""
## Hyperparameters

To keep run time short, we will train the model for only 10 epochs. To focus on
the core ideas of ConvMixer, we will not use other training-specific elements like
RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)). If you are interested in
learning more about those details, please refer to the
[original paper](https://openreview.net/pdf?id=TVHS5Y4dNvM).
"""

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 10

"""
## Load the CIFAR-10 dataset
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
val_split = 0.1

val_indices = int(len(x_train) * val_split)
new_x_train, new_y_train = x_train[val_indices:], y_train[val_indices:]
x_val, y_val = x_train[:val_indices], y_train[:val_indices]

print(f"Training data samples: {len(new_x_train)}")
print(f"Validation data samples: {len(x_val)}")
print(f"Test data samples: {len(x_test)}")

"""
## Prepare `tf.data.Dataset` objects

Our data augmentation pipeline is different from what the authors used for the CIFAR-10
dataset, which is fine for the purpose of this example.
Note that it is fine to use **TF APIs for data I/O and preprocessing** with other backends
(JAX, PyTorch), as TensorFlow is a feature-complete framework when it comes to data
preprocessing.
"""

image_size = 32
auto = tf.data.AUTOTUNE

augmentation_layers = [
    keras.layers.RandomCrop(image_size, image_size),
    keras.layers.RandomFlip("horizontal"),
]


def augment_images(images):
    for layer in augmentation_layers:
        images = layer(images, training=True)
    return images


def make_datasets(images, labels, is_train=False):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_train:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.batch(batch_size)
    if is_train:
        dataset = dataset.map(
            lambda x, y: (augment_images(x), y), num_parallel_calls=auto
        )
    return dataset.prefetch(auto)


train_dataset = make_datasets(new_x_train, new_y_train, is_train=True)
val_dataset = make_datasets(x_val, y_val)
test_dataset = make_datasets(x_test, y_test)

"""
## ConvMixer utilities

The following figure (taken from the original paper) depicts the ConvMixer model:

![](https://i.imgur.com/yF8actg.png)

ConvMixer is very similar to the MLP-Mixer model, with the following key
differences:

* Instead of using fully-connected layers, it uses standard convolution layers.
* Instead of LayerNorm (which is typical for ViTs and MLP-Mixers), it uses BatchNorm.

Two types of convolution layers are used in ConvMixer: **(1)** depthwise convolutions,
for mixing spatial locations of the images, and **(2)** pointwise convolutions (which
follow the depthwise convolutions), for mixing channel-wise information across the patches.
Another key point is the use of *larger kernel sizes* to allow a larger receptive field.
"""


def activation_block(x):
    x = layers.Activation("gelu")(x)
    return layers.BatchNormalization()(x)


def conv_stem(x, filters: int, patch_size: int):
    x = layers.Conv2D(filters, kernel_size=patch_size, strides=patch_size)(x)
    return activation_block(x)


def conv_mixer_block(x, filters: int, kernel_size: int):
    # Depthwise convolution.
    x0 = x
    x = layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
    x = layers.Add()([activation_block(x), x0])  # Residual.

    # Pointwise convolution.
    x = layers.Conv2D(filters, kernel_size=1)(x)
    x = activation_block(x)

    return x


def get_conv_mixer_256_8(
    image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10
):
    """ConvMixer-256/8: https://openreview.net/pdf?id=TVHS5Y4dNvM.
    The hyperparameter values are taken from the paper.
    """
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Extract patch embeddings.
    x = conv_stem(x, filters, patch_size)

    # ConvMixer blocks.
    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size)

    # Classification block.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)

"""
The model used in this experiment is termed **ConvMixer-256/8**, where 256 denotes the
number of channels and 8 denotes the depth. The resulting model only has 0.8 million
parameters.
"""

"""
## Model training and evaluation utility
"""

# Code reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/.


def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint.keras"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=False,
    )

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy = model.evaluate(test_dataset)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    return history, model

"""
## Train and evaluate model
"""

conv_mixer_model = get_conv_mixer_256_8()
history, conv_mixer_model = run_experiment(conv_mixer_model)

"""
The gap in training and validation performance can be mitigated by using additional
regularization techniques. Nevertheless, being able to get to ~83% accuracy within 10
epochs with 0.8 million parameters is a strong result.
"""

"""
## Visualizing the internals of ConvMixer

We can visualize the patch embeddings and the learned convolution filters. Recall
that each patch embedding and intermediate feature map has the same number of channels
(256 in this case). This makes our visualization utility easier to implement.
"""

# Code reference: https://bit.ly/3awIRbP.


def visualization_plot(weights, idx=1):
    # First, apply min-max normalization to the
    # given weights to avoid isotropic scaling.
    p_min, p_max = weights.min(), weights.max()
    weights = (weights - p_min) / (p_max - p_min)

    # Visualize all the filters.
    num_filters = 256
    plt.figure(figsize=(8, 8))

    for i in range(num_filters):
        current_weight = weights[:, :, :, i]
        if current_weight.shape[-1] == 1:
            current_weight = current_weight.squeeze()
        ax = plt.subplot(16, 16, idx)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.imshow(current_weight)
        idx += 1


# We first visualize the learned patch embeddings.
patch_embeddings = conv_mixer_model.layers[2].get_weights()[0]
visualization_plot(patch_embeddings)

"""
Even though we did not train the network to convergence, we can notice that different
patches show different patterns. Some share similarity with others while some are very
different. These visualizations are more salient with larger image sizes.

Similarly, we can visualize the raw convolution kernels. This can help us understand
the patterns to which a given kernel is receptive.
"""

# First, print the indices of the convolution layers that are not
# pointwise convolutions.
for i, layer in enumerate(conv_mixer_model.layers):
    if isinstance(layer, layers.DepthwiseConv2D):
        if layer.get_config()["kernel_size"] == (5, 5):
            print(i, layer)

idx = 26  # Taking a kernel from the middle of the network.

kernel = conv_mixer_model.layers[idx].get_weights()[0]
kernel = np.expand_dims(kernel.squeeze(), axis=2)
visualization_plot(kernel)

"""
We see that different filters in the kernel have different locality spans, and this
pattern is likely to evolve with more training.
"""

"""
## Final notes

There has been a recent trend of fusing convolutions with other data-agnostic operations
like self-attention. The following works are along this line of research:

* ConViT ([d'Ascoli et al.](https://arxiv.org/abs/2103.10697))
* CCT ([Hassani et al.](https://arxiv.org/abs/2104.05704))
* CoAtNet ([Dai et al.](https://arxiv.org/abs/2106.04803))
"""