# Image classification with modern MLP models

**Author:** Khalid Salama<br>
**Date created:** 2021/05/30<br>
**Last modified:** 2023/08/03<br>
**Description:** Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.
## Introduction

This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image classification, demonstrated on the CIFAR-100 dataset:

1. The MLP-Mixer model, by Ilya Tolstikhin et al., based on two types of MLPs.
2. The FNet model, by James Lee-Thorp et al., based on the unparameterized Fourier Transform.
3. The gMLP model, by Hanxiao Liu et al., based on MLPs with gating.

The purpose of the example is not to compare these models, as they might perform differently on different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their main building blocks.
## Setup
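A minimal sketch of the imports this example relies on, assuming Keras 3 with any of its supported backends:

```python
import numpy as np
import keras
from keras import layers
```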
## Prepare the data
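A minimal data-loading sketch; CIFAR-100 ships with Keras, so no extra download code is needed:

```python
num_classes = 100
input_shape = (32, 32, 3)

# CIFAR-100: 50,000 training and 10,000 test images, 100 fine-grained labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
```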
## Define an experiment
We implement a utility function to compile, train, and evaluate a given model.
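The sketch below shows one way to write such a utility. The optimizer, callback, and hyperparameter values (`learning_rate`, `batch_size`, `num_epochs`) are illustrative assumptions, not tuned settings:

```python
# Illustrative hyperparameters -- not tuned values.
learning_rate = 0.005
weight_decay = 0.0001
batch_size = 128
num_epochs = 10  # A real run would use more epochs, e.g. 50.


def run_experiment(model):
    # Compile the model with AdamW, reporting top-1 and top-5 accuracy.
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Halve the learning rate when the validation loss plateaus.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[reduce_lr],
    )
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top-5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return history
```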
## Use data augmentation
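A plausible augmentation pipeline built from Keras preprocessing layers; the target `image_size` of 64 and the specific augmentations are assumptions of this sketch:

```python
image_size = 64  # Assumed target size: images are upsampled before patching.

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# The Normalization layer computes mean/variance from the training data.
data_augmentation.layers[0].adapt(x_train)
```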
## Implement patch extraction as a layer
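One way to implement this is with `keras.ops.image.extract_patches`; the `patch_size` of 8 is an illustrative choice:

```python
patch_size = 8  # Assumed size: a 64x64 image yields (64 / 8)^2 = 64 patches.
num_patches = (image_size // patch_size) ** 2


class Patches(layers.Layer):
    """Splits a batch of images into flattened, non-overlapping patches."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        # (batch, h, w, c) -> (batch, h/p, w/p, p * p * c).
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch_size = keras.ops.shape(patches)[0]
        patch_dim = keras.ops.shape(patches)[-1]
        # Flatten the spatial grid of patches into a sequence.
        return keras.ops.reshape(patches, (batch_size, -1, patch_dim))
```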
## Implement position embedding as a layer
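A minimal sketch of such a layer, creating one trainable vector per patch position and adding it to the patch features:

```python
class PositionEmbedding(layers.Layer):
    """Adds a trainable embedding vector to each position of a sequence."""

    def __init__(self, sequence_length, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length

    def build(self, input_shape):
        # One embedding per position, with the same width as the features.
        self.position_embeddings = self.add_weight(
            name="position_embeddings",
            shape=(self.sequence_length, input_shape[-1]),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # The (sequence_length, dim) table broadcasts over the batch axis.
        return inputs + self.position_embeddings
```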
## The MLP-Mixer model

The MLP-Mixer is an architecture based exclusively on multi-layer perceptrons (MLPs), which contains two types of MLP layers:

1. One applied independently to each image patch, which mixes the per-location features (channel mixing).
2. The other applied across patches, operating on each channel independently, which mixes spatial information (token mixing).

This is similar to a depthwise separable convolution based model such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization instead of batch normalization.
### Implement the MLP-Mixer module
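A sketch of a single Mixer block following the description above; `embedding_dim` and `dropout_rate` are illustrative values:

```python
embedding_dim = 256  # Illustrative hidden width.
dropout_rate = 0.2  # Illustrative dropout rate.


class MLPMixerLayer(layers.Layer):
    """One Mixer block: token-mixing MLP, then channel-mixing MLP."""

    def __init__(self, num_patches, hidden_units, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        # Token-mixing MLP: mixes information across patches.
        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        # Channel-mixing MLP: mixes features within each patch.
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=hidden_units, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Token mixing: transpose to (batch, channels, patches),
        # apply the MLP along the patch axis, then transpose back.
        x = self.norm1(inputs)
        x = keras.ops.transpose(x, axes=(0, 2, 1))
        x = self.mlp1(x)
        x = keras.ops.transpose(x, axes=(0, 2, 1))
        x = x + inputs  # Residual connection.
        # Channel mixing with a second residual connection.
        return x + self.mlp2(self.norm2(x))
```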
### Build, train, and evaluate the MLP-Mixer model

Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
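A sketch of how the pieces above could be assembled and run. The `build_classifier` helper, which is reused for FNet and gMLP below, and the choice of 4 blocks are assumptions of this sketch:

```python
def build_classifier(blocks, positional_encoding=False):
    inputs = keras.Input(shape=input_shape)
    # Augment, patch, and linearly project the input images.
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        x = PositionEmbedding(sequence_length=num_patches)(x)
    # Process the patch sequence with the given stack of blocks.
    x = blocks(x)
    # Pool over patches and classify.
    representation = layers.GlobalAveragePooling1D()(x)
    representation = layers.Dropout(rate=dropout_rate)(representation)
    logits = layers.Dense(num_classes)(representation)
    return keras.Model(inputs=inputs, outputs=logits)


mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(4)]
)
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)
```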
## Build, train, and evaluate the FNet model

Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.
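A sketch of one FNet block, which replaces self-attention with a parameter-free 2D Fourier transform (here via `keras.ops.fft2`), keeping only the real part of the result. The block count and the use of position embeddings are design choices of this sketch:

```python
class FNetLayer(layers.Layer):
    """One FNet block: parameter-free Fourier mixing plus a feedforward network."""

    def __init__(self, embedding_dim, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # 2D FFT over the (patches, channels) axes; keep only the real part.
        real_part, _ = keras.ops.fft2((inputs, keras.ops.zeros_like(inputs)))
        x = self.norm1(real_part + inputs)  # Residual connection + normalization.
        return self.norm2(x + self.ffn(x))  # Feedforward + residual + normalization.


fnet_blocks = keras.Sequential(
    [FNetLayer(embedding_dim, dropout_rate) for _ in range(4)]
)
fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
history = run_experiment(fnet_classifier)
```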
## Build, train, and evaluate the gMLP model

Note that training the model with the current settings on a V100 GPU takes around 9 seconds per epoch.
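A sketch of one gMLP block with its Spatial Gating Unit: the channel-expanded features are split into two halves, one half is projected across the patch axis, and the two halves are multiplied element-wise as a gate. Again, the hyperparameters are illustrative:

```python
class gMLPLayer(layers.Layer):
    """One gMLP block: channel projection, spatial gating unit, channel projection."""

    def __init__(self, num_patches, embedding_dim, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        # Expand the channels before gating.
        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.channel_projection2 = layers.Dense(units=embedding_dim)
        # Bias initialized to ones so the gate starts close to identity.
        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="ones"
        )
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split the expanded channels into two halves, u and v.
        u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
        # Project v across the patch axis, then gate u with it.
        v = self.norm2(v)
        v = keras.ops.transpose(v, axes=(0, 2, 1))
        v = self.spatial_projection(v)
        v = keras.ops.transpose(v, axes=(0, 2, 1))
        return u * v

    def call(self, inputs):
        x = self.norm1(inputs)
        x = self.channel_projection1(x)
        x = self.spatial_gating_unit(x)
        x = self.channel_projection2(x)
        return inputs + x  # Residual connection.


gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(4)]
)
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)
```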