"""
Title: MobileViT: A mobile-friendly Transformer-based model for image classification
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
Accelerator: GPU
"""
"""
## Introduction
In this example, we implement the MobileViT architecture
([Mehta et al.](https://arxiv.org/abs/2110.02178)),
which combines the benefits of Transformers
([Vaswani et al.](https://arxiv.org/abs/1706.03762))
and convolutions. With Transformers, we can capture long-range dependencies that result
in global representations. With convolutions, we can capture spatial relationships that
model locality.

Besides combining the properties of Transformers and convolutions, the authors introduce
MobileViT as a general-purpose, mobile-friendly backbone for different image recognition
tasks. Their findings suggest that, performance-wise, MobileViT is better than other
models with the same or higher complexity ([MobileNetV3](https://arxiv.org/abs/1905.02244),
for example), while being efficient on mobile devices.

Note: This example should be run with TensorFlow 2.13 or higher.
"""
"""
## Imports
"""
import os
import tensorflow as tf

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import backend

import tensorflow_datasets as tfds

tfds.disable_progress_bar()
"""
## Hyperparameters
"""
patch_size = 4
image_size = 256
expansion_factor = 2
"""
## MobileViT utilities
The MobileViT architecture comprises the following blocks:

* Strided 3x3 convolutions that process the input image.
* [MobileNetV2](https://arxiv.org/abs/1801.04381)-style inverted residual blocks for
downsampling the resolution of the intermediate feature maps.
* MobileViT blocks that combine the benefits of Transformers and convolutions. They are
presented in the figure below (taken from the
[original paper](https://arxiv.org/abs/2110.02178)):

"""
def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        activation=keras.activations.swish,
        padding="same",
    )
    return conv_layer(x)


def correct_pad(inputs, kernel_size):
    # Returns a tuple for zero-padding before a strided 2D convolution, so that the
    # "valid" stride-2 depthwise convolution below behaves like a "same" one.
    img_dim = 2 if backend.image_data_format() == "channels_first" else 1
    input_size = inputs.shape[img_dim : (img_dim + 2)]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )


def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    # Expand with a point-wise convolution.
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    # Depthwise convolution (with explicit padding when downsampling).
    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)

    # Project with a point-wise convolution and add the residual when shapes allow it.
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Multi-head self-attention.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(
            x3,
            hidden_units=[x.shape[-1] * 2, x.shape[-1]],
            dropout_rate=0.1,
        )
        # Skip connection 2.
        x = layers.Add()([x3, x2])
    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features
"""
**More on the MobileViT block**:

* First, the feature representations (A) go through convolution blocks that capture local
relationships. The expected shape of a single entry here would be `(h, w, num_channels)`.
* Then they get unfolded into another vector with shape `(p, n, num_channels)`,
where `p` is the area of a small patch, and `n` is `(h * w) / p`. So, we end up with `n`
non-overlapping patches.
* This unfolded vector is then passed through a Transformer block that captures global
relationships between the patches.
* The output vector (B) is again folded into a vector of shape `(h, w, num_channels)`,
resembling a feature map coming out of convolutions.

Vectors A and B are then passed through two more convolutional layers to fuse the local
and global representations. Notice how the spatial resolution of the final vector remains
unchanged at this point. The authors also present an explanation of how the MobileViT
block resembles a convolution block of a CNN. For more details, please refer to the
original paper.
"""
"""
Next, we combine these blocks and implement the MobileViT architecture (XXS variant).
The following figure (taken from the original paper) presents a schematic representation
of the architecture:

"""
def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()
"""
## Dataset preparation
We will be using the
[`tf_flowers`](https://www.tensorflow.org/datasets/catalog/tf_flowers)
dataset to demonstrate the model. Unlike other Transformer-based architectures,
MobileViT uses a simple augmentation pipeline primarily because it has the properties
of a CNN.
"""
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5
def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take random crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)
"""
The authors use a multi-scale data sampler to help the model learn representations of
varied scales. For simplicity, we leave it out of this example; a rough sketch of the
idea is shown below.
"""
"""
## Load and prepare the dataset
"""
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
"""
## Train a MobileViT (XXS) model
"""
learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)
def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    # Keep the best weights (by validation accuracy) seen during training.
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()
"""
## Results and TFLite conversion
With about one million parameters, getting to ~85% top-1 accuracy at a 256x256 resolution
is a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite)
and can be converted with the following code:
"""
# Serialize the model as a SavedModel and convert it to TFLite. With no representative
# dataset, Optimize.DEFAULT applies post-training dynamic-range quantization.
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TFLite builtin ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable select TF ops without TFLite builtins.
]
tflite_model = converter.convert()
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
"""
To learn more about different quantization recipes available in TFLite and running
inference with TFLite models, check out
[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).
"""