"""
Title: MobileViT: A mobile-friendly Transformer-based model for image classification
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
Accelerator: GPU
"""

"""
## Introduction

In this example, we implement the MobileViT architecture
([Mehta et al.](https://arxiv.org/abs/2110.02178)),
which combines the benefits of Transformers
([Vaswani et al.](https://arxiv.org/abs/1706.03762))
and convolutions. With Transformers, we can capture long-range dependencies that result
in global representations. With convolutions, we can capture spatial relationships that
model locality.

Besides combining the properties of Transformers and convolutions, the authors introduce
MobileViT as a general-purpose mobile-friendly backbone for different image recognition
tasks. Their findings suggest that, performance-wise, MobileViT is better than other
models with the same or higher complexity ([MobileNetV3](https://arxiv.org/abs/1905.02244),
for example), while being efficient on mobile devices.

Note: This example should be run with TensorFlow 2.13 or higher.
"""

"""
## Imports
"""

import os
import tensorflow as tf

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import backend

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

"""
## Hyperparameters
"""

# Values are from Table 4 of the paper.
patch_size = 4  # 2x2 patches (patch area = 4) for the Transformer blocks.
image_size = 256
expansion_factor = 2  # Expansion factor for the MobileNetV2 blocks.

"""
## MobileViT utilities

The MobileViT architecture comprises the following blocks:

* Strided 3x3 convolutions that process the input image.
* [MobileNetV2](https://arxiv.org/abs/1801.04381)-style inverted residual blocks for
downsampling the resolution of the intermediate feature maps.
* MobileViT blocks that combine the benefits of Transformers and convolutions. It is
presented in the figure below (taken from the
[original paper](https://arxiv.org/abs/2110.02178)):


![](https://i.imgur.com/mANnhI7.png)
"""


def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        activation=keras.activations.swish,
        padding="same",
    )
    return conv_layer(x)


# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435
# Returns a padding tuple for `ZeroPadding2D` so that the subsequent strided
# convolution with `padding="valid"` downsamples the feature map the same way
# `padding="same"` would.


def correct_pad(inputs, kernel_size):
    img_dim = 2 if backend.image_data_format() == "channels_first" else 1
    input_size = inputs.shape[img_dim : (img_dim + 2)]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )
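
"""
As a small illustration (not part of the model itself), for a 3x3 kernel `correct_pad`
returns `((0, 1), (0, 1))` for an even-sized 128x128 feature map and `((1, 1), (1, 1))`
for an odd-sized 129x129 one, so the strided `"valid"` convolution that follows
downsamples the feature map exactly as `padding="same"` would:
"""

print(correct_pad(keras.ops.zeros((1, 128, 128, 16)), 3))
print(correct_pad(keras.ops.zeros((1, 129, 129, 16)), 3))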


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m


# Reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(
            x3,
            hidden_units=[x.shape[-1] * 2, x.shape[-1]],
            dropout_rate=0.1,
        )
        # Skip connection 2.
        x = layers.Add()([x3, x2])

    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features


"""
**More on the MobileViT block**:

* First, the feature representations (A) go through convolution blocks that capture local
relationships. The expected shape of a single entry here would be `(h, w, num_channels)`.
* Then they get unfolded into another vector with shape `(p, n, num_channels)`,
where `p` is the area of a small patch, and `n` is `(h * w) / p`. So, we end up with `n`
non-overlapping patches.
* This unfolded vector is then passed through a Transformer block that captures global
relationships between the patches.
* The output vector (B) is again folded into a vector of shape `(h, w, num_channels)`
resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local
and global representations. Notice how the spatial resolution of the final vector remains
unchanged at this point. The authors also present an explanation of how the MobileViT
block resembles a convolution block of a CNN. For more details, please refer to the
original paper. The short snippet that follows sanity-checks this unfold/fold shape
bookkeeping on a toy tensor.
"""

"""
Next, we combine these blocks and implement the MobileViT architecture (XXS
variant). The following figure (taken from the original paper) presents a schematic
representation of the architecture:

![](https://i.ibb.co/sRbVRBN/image.png)
"""


def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()

"""
## Dataset preparation

We will be using the
[`tf_flowers`](https://www.tensorflow.org/datasets/catalog/tf_flowers)
dataset to demonstrate the model. Unlike other Transformer-based architectures,
MobileViT uses a simple augmentation pipeline primarily because it has the properties
of a CNN.
"""

batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5


def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take random crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)


"""
The authors use a multi-scale data sampler to help the model learn representations of
varied scales. In this example, we discard this part.
"""

"""
## Load and prepare the dataset
"""

train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

"""
## Train a MobileViT (XXS) model
"""

learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)


def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()

"""
## Results and TFLite conversion

With about one million parameters, getting to ~85% top-1 accuracy at 256x256 resolution is
a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite)
and can be converted with the following code:
"""

# Serialize the model as a SavedModel.
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]
tflite_model = converter.convert()
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)

"""
To learn more about different quantization recipes available in TFLite and running
inference with TFLite models, check out
[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).
"""