Path: blob/master/examples/vision/mlp_image_classification.py
3507 views
"""1Title: Image classification with modern MLP models2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/05/304Last modified: 2023/08/035Description: Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.6Accelerator: GPU7"""89"""10## Introduction1112This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image13classification, demonstrated on the CIFAR-100 dataset:14151. The [MLP-Mixer](https://arxiv.org/abs/2105.01601) model, by Ilya Tolstikhin et al., based on two types of MLPs.163. The [FNet](https://arxiv.org/abs/2105.03824) model, by James Lee-Thorp et al., based on unparameterized17Fourier Transform.182. The [gMLP](https://arxiv.org/abs/2105.08050) model, by Hanxiao Liu et al., based on MLP with gating.1920The purpose of the example is not to compare between these models, as they might perform differently on21different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their22main building blocks.23"""2425"""26## Setup27"""2829import numpy as np30import keras31from keras import layers3233"""34## Prepare the data35"""3637num_classes = 10038input_shape = (32, 32, 3)3940(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()4142print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")43print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")4445"""46## Configure the hyperparameters47"""4849weight_decay = 0.000150batch_size = 12851num_epochs = 1 # Recommended num_epochs = 5052dropout_rate = 0.253image_size = 64 # We'll resize input images to this size.54patch_size = 8 # Size of the patches to be extracted from the input images.55num_patches = (image_size // patch_size) ** 2 # Size of the data array.56embedding_dim = 256 # Number of hidden units.57num_blocks = 4 # Number of blocks.5859print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")60print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")61print(f"Patches per image: {num_patches}")62print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")6364"""65## Build a classification model6667We implement a method that builds a classifier given the processing blocks.68"""697071def build_classifier(blocks, positional_encoding=False):72inputs = layers.Input(shape=input_shape)73# Augment data.74augmented = data_augmentation(inputs)75# Create patches.76patches = Patches(patch_size)(augmented)77# Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.78x = layers.Dense(units=embedding_dim)(patches)79if positional_encoding:80x = x + PositionEmbedding(sequence_length=num_patches)(x)81# Process x using the module blocks.82x = blocks(x)83# Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.84representation = layers.GlobalAveragePooling1D()(x)85# Apply dropout.86representation = layers.Dropout(rate=dropout_rate)(representation)87# Compute logits outputs.88logits = layers.Dense(num_classes)(representation)89# Create the Keras model.90return keras.Model(inputs=inputs, outputs=logits)919293"""94## Define an experiment9596We implement a utility function to compile, train, and evaluate a given model.97"""9899100def run_experiment(model):101# Create Adam optimizer with weight decay.102optimizer = keras.optimizers.AdamW(103learning_rate=learning_rate,104weight_decay=weight_decay,105)106# Compile the model.107model.compile(108optimizer=optimizer,109loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),110metrics=[111keras.metrics.SparseCategoricalAccuracy(name="acc"),112keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),113],114)115# Create a learning rate scheduler callback.116reduce_lr = keras.callbacks.ReduceLROnPlateau(117monitor="val_loss", factor=0.5, patience=5118)119# Create an early stopping callback.120early_stopping = keras.callbacks.EarlyStopping(121monitor="val_loss", patience=10, restore_best_weights=True122)123# Fit the model.124history = model.fit(125x=x_train,126y=y_train,127batch_size=batch_size,128epochs=num_epochs,129validation_split=0.1,130callbacks=[early_stopping, reduce_lr],131)132133_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)134print(f"Test accuracy: {round(accuracy * 100, 2)}%")135print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")136137# Return history to plot learning curves.138return history139140141"""142## Use data augmentation143"""144145data_augmentation = keras.Sequential(146[147layers.Normalization(),148layers.Resizing(image_size, image_size),149layers.RandomFlip("horizontal"),150layers.RandomZoom(height_factor=0.2, width_factor=0.2),151],152name="data_augmentation",153)154# Compute the mean and the variance of the training data for normalization.155data_augmentation.layers[0].adapt(x_train)156157158"""159## Implement patch extraction as a layer160"""161162163class Patches(layers.Layer):164def __init__(self, patch_size, **kwargs):165super().__init__(**kwargs)166self.patch_size = patch_size167168def call(self, x):169patches = keras.ops.image.extract_patches(x, self.patch_size)170batch_size = keras.ops.shape(patches)[0]171num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]172patch_dim = keras.ops.shape(patches)[3]173out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))174return out175176177"""178## Implement position embedding as a layer179"""180181182class PositionEmbedding(keras.layers.Layer):183def __init__(184self,185sequence_length,186initializer="glorot_uniform",187**kwargs,188):189super().__init__(**kwargs)190if sequence_length is None:191raise ValueError("`sequence_length` must be an Integer, received `None`.")192self.sequence_length = int(sequence_length)193self.initializer = keras.initializers.get(initializer)194195def get_config(self):196config = super().get_config()197config.update(198{199"sequence_length": self.sequence_length,200"initializer": keras.initializers.serialize(self.initializer),201}202)203return config204205def build(self, input_shape):206feature_size = input_shape[-1]207self.position_embeddings = self.add_weight(208name="embeddings",209shape=[self.sequence_length, feature_size],210initializer=self.initializer,211trainable=True,212)213214super().build(input_shape)215216def call(self, inputs, start_index=0):217shape = keras.ops.shape(inputs)218feature_length = shape[-1]219sequence_length = shape[-2]220# trim to match the length of the input sequence, which might be less221# than the sequence_length of the layer.222position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)223position_embeddings = keras.ops.slice(224position_embeddings,225(start_index, 0),226(sequence_length, feature_length),227)228return keras.ops.broadcast_to(position_embeddings, shape)229230def compute_output_shape(self, input_shape):231return input_shape232233234"""235## The MLP-Mixer model236237The MLP-Mixer is an architecture based exclusively on238multi-layer perceptrons (MLPs), that contains two types of MLP layers:2392401. One applied independently to image patches, which mixes the per-location features.2412. The other applied across patches (along channels), which mixes spatial information.242243This is similar to a [depthwise separable convolution based model](https://arxiv.org/abs/1610.02357)244such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization245instead of batch normalization.246"""247248"""249### Implement the MLP-Mixer module250"""251252253class MLPMixerLayer(layers.Layer):254def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):255super().__init__(*args, **kwargs)256257self.mlp1 = keras.Sequential(258[259layers.Dense(units=num_patches, activation="gelu"),260layers.Dense(units=num_patches),261layers.Dropout(rate=dropout_rate),262]263)264self.mlp2 = keras.Sequential(265[266layers.Dense(units=num_patches, activation="gelu"),267layers.Dense(units=hidden_units),268layers.Dropout(rate=dropout_rate),269]270)271self.normalize = layers.LayerNormalization(epsilon=1e-6)272273def build(self, input_shape):274return super().build(input_shape)275276def call(self, inputs):277# Apply layer normalization.278x = self.normalize(inputs)279# Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].280x_channels = keras.ops.transpose(x, axes=(0, 2, 1))281# Apply mlp1 on each channel independently.282mlp1_outputs = self.mlp1(x_channels)283# Transpose mlp1_outputs from [num_batches, hidden_units, num_patches] to [num_batches, num_patches, hidden_units].284mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))285# Add skip connection.286x = mlp1_outputs + inputs287# Apply layer normalization.288x_patches = self.normalize(x)289# Apply mlp2 on each patch independtenly.290mlp2_outputs = self.mlp2(x_patches)291# Add skip connection.292x = x + mlp2_outputs293return x294295296"""297### Build, train, and evaluate the MLP-Mixer model298299Note that training the model with the current settings on a V100 GPUs300takes around 8 seconds per epoch.301"""302303mlpmixer_blocks = keras.Sequential(304[MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]305)306learning_rate = 0.005307mlpmixer_classifier = build_classifier(mlpmixer_blocks)308history = run_experiment(mlpmixer_classifier)309310"""311The MLP-Mixer model tends to have much less number of parameters compared312to convolutional and transformer-based models, which leads to less training and313serving computational cost.314315As mentioned in the [MLP-Mixer](https://arxiv.org/abs/2105.01601) paper,316when pre-trained on large datasets, or with modern regularization schemes,317the MLP-Mixer attains competitive scores to state-of-the-art models.318You can obtain better results by increasing the embedding dimensions,319increasing the number of mixer blocks, and training the model for longer.320You may also try to increase the size of the input images and use different patch sizes.321"""322323"""324## The FNet model325326The FNet uses a similar block to the Transformer block. However, FNet replaces the self-attention layer327in the Transformer block with a parameter-free 2D Fourier transformation layer:3283291. One 1D Fourier Transform is applied along the patches.3302. One 1D Fourier Transform is applied along the channels.331"""332333"""334### Implement the FNet module335"""336337338class FNetLayer(layers.Layer):339def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):340super().__init__(*args, **kwargs)341342self.ffn = keras.Sequential(343[344layers.Dense(units=embedding_dim, activation="gelu"),345layers.Dropout(rate=dropout_rate),346layers.Dense(units=embedding_dim),347]348)349350self.normalize1 = layers.LayerNormalization(epsilon=1e-6)351self.normalize2 = layers.LayerNormalization(epsilon=1e-6)352353def call(self, inputs):354# Apply fourier transformations.355real_part = inputs356im_part = keras.ops.zeros_like(inputs)357x = keras.ops.fft2((real_part, im_part))[0]358# Add skip connection.359x = x + inputs360# Apply layer normalization.361x = self.normalize1(x)362# Apply Feedfowrad network.363x_ffn = self.ffn(x)364# Add skip connection.365x = x + x_ffn366# Apply layer normalization.367return self.normalize2(x)368369370"""371### Build, train, and evaluate the FNet model372373Note that training the model with the current settings on a V100 GPUs374takes around 8 seconds per epoch.375"""376377fnet_blocks = keras.Sequential(378[FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]379)380learning_rate = 0.001381fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)382history = run_experiment(fnet_classifier)383384"""385As shown in the [FNet](https://arxiv.org/abs/2105.03824) paper,386better results can be achieved by increasing the embedding dimensions,387increasing the number of FNet blocks, and training the model for longer.388You may also try to increase the size of the input images and use different patch sizes.389The FNet scales very efficiently to long inputs, runs much faster than attention-based390Transformer models, and produces competitive accuracy results.391"""392393"""394## The gMLP model395396The gMLP is a MLP architecture that features a Spatial Gating Unit (SGU).397The SGU enables cross-patch interactions across the spatial (channel) dimension, by:3983991. Transforming the input spatially by applying linear projection across patches (along channels).4002. Applying element-wise multiplication of the input and its spatial transformation.401"""402403"""404### Implement the gMLP module405"""406407408class gMLPLayer(layers.Layer):409def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):410super().__init__(*args, **kwargs)411412self.channel_projection1 = keras.Sequential(413[414layers.Dense(units=embedding_dim * 2, activation="gelu"),415layers.Dropout(rate=dropout_rate),416]417)418419self.channel_projection2 = layers.Dense(units=embedding_dim)420421self.spatial_projection = layers.Dense(422units=num_patches, bias_initializer="Ones"423)424425self.normalize1 = layers.LayerNormalization(epsilon=1e-6)426self.normalize2 = layers.LayerNormalization(epsilon=1e-6)427428def spatial_gating_unit(self, x):429# Split x along the channel dimensions.430# Tensors u and v will in the shape of [batch_size, num_patchs, embedding_dim].431u, v = keras.ops.split(x, indices_or_sections=2, axis=2)432# Apply layer normalization.433v = self.normalize2(v)434# Apply spatial projection.435v_channels = keras.ops.transpose(v, axes=(0, 2, 1))436v_projected = self.spatial_projection(v_channels)437v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))438# Apply element-wise multiplication.439return u * v_projected440441def call(self, inputs):442# Apply layer normalization.443x = self.normalize1(inputs)444# Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].445x_projected = self.channel_projection1(x)446# Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].447x_spatial = self.spatial_gating_unit(x_projected)448# Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].449x_projected = self.channel_projection2(x_spatial)450# Add skip connection.451return x + x_projected452453454"""455### Build, train, and evaluate the gMLP model456457Note that training the model with the current settings on a V100 GPUs458takes around 9 seconds per epoch.459"""460461gmlp_blocks = keras.Sequential(462[gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]463)464learning_rate = 0.003465gmlp_classifier = build_classifier(gmlp_blocks)466history = run_experiment(gmlp_classifier)467468"""469As shown in the [gMLP](https://arxiv.org/abs/2105.08050) paper,470better results can be achieved by increasing the embedding dimensions,471increasing the number of gMLP blocks, and training the model for longer.472You may also try to increase the size of the input images and use different patch sizes.473Note that, the paper used advanced regularization strategies, such as MixUp and CutMix,474as well as AutoAugment.475"""476477478