"""
Title: Image classification with modern MLP models
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2021/05/30
Last modified: 2023/08/03
Description: Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.
Accelerator: GPU
"""

"""
## Introduction

This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image
classification, demonstrated on the CIFAR-100 dataset:

1. The [MLP-Mixer](https://arxiv.org/abs/2105.01601) model, by Ilya Tolstikhin et al., based on two types of MLPs.
2. The [FNet](https://arxiv.org/abs/2105.03824) model, by James Lee-Thorp et al., based on unparameterized
Fourier Transform.
3. The [gMLP](https://arxiv.org/abs/2105.08050) model, by Hanxiao Liu et al., based on MLP with gating.

The purpose of the example is not to compare these models, as they might perform differently on
different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their
main building blocks.
"""

"""
## Setup
"""

import numpy as np
import keras
from keras import layers

"""
## Prepare the data
"""

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

"""
## Configure the hyperparameters
"""

weight_decay = 0.0001
batch_size = 128
num_epochs = 1  # Recommended num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2}")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

"""
## Build a classification model

We implement a method that builds a classifier given the processing blocks.
"""


def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        x = x + PositionEmbedding(sequence_length=num_patches)(x)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)


"""
## Define an experiment

We implement a utility function to compile, train, and evaluate a given model.
"""


def run_experiment(model):
    # Create an AdamW optimizer with weight decay. Note that `learning_rate`
    # is a module-level variable set before each experiment below.
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history
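

"""
`run_experiment` returns the Keras `History` object, so learning curves can be
plotted after a run. The helper below is a minimal sketch (not part of the
original example) and assumes `matplotlib` is installed:
"""


def plot_history(history, metric="loss"):
    import matplotlib.pyplot as plt

    # Plot the training and validation curves of the given metric.
    plt.plot(history.history[metric], label=f"train {metric}")
    plt.plot(history.history[f"val_{metric}"], label=f"val {metric}")
    plt.xlabel("epoch")
    plt.ylabel(metric)
    plt.legend()
    plt.show()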


"""
## Use data augmentation
"""

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
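
"""
As an illustrative check (not part of the original example), running one small
batch through the pipeline shows the 32 x 32 CIFAR-100 images resized to
64 x 64 before patch extraction:
"""

print(data_augmentation(x_train[:4]).shape)  # [4, image_size, image_size, 3]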


"""
## Implement patch extraction as a layer
"""


class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch_size = keras.ops.shape(patches)[0]
        num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
        patch_dim = keras.ops.shape(patches)[3]
        out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))
        return out
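

"""
A minimal usage sketch (an illustrative addition, not part of the original
example): applying the layer to a dummy batch of resized images yields a
[batch_size, num_patches, patch_dim] tensor, here [4, 64, 192]:
"""

demo_patches = Patches(patch_size)(keras.ops.ones((4, image_size, image_size, 3)))
print(demo_patches.shape)  # (4, 64, 192)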


"""
## Implement position embedding as a layer
"""


class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # Trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape
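

"""
A quick shape check (an illustrative addition, not part of the original
example): the layer learns one embedding row per position and broadcasts the
rows over the batch dimension, so the output shape matches the input shape:
"""

demo_positions = PositionEmbedding(sequence_length=num_patches)(
    keras.ops.zeros((2, num_patches, embedding_dim))
)
print(demo_positions.shape)  # (2, 64, 256)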


"""
## The MLP-Mixer model

The MLP-Mixer is an architecture based exclusively on
multi-layer perceptrons (MLPs). It contains two types of MLP layers:

1. One applied independently to image patches, which mixes the per-location features.
2. The other applied across patches (independently for each channel), which mixes spatial information.

This is similar to a [depthwise separable convolution based model](https://arxiv.org/abs/1610.02357)
such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization
instead of batch normalization.
"""

"""
### Implement the MLP-Mixer module
"""


class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = keras.ops.transpose(x, axes=(0, 2, 1))
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_units, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x
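

"""
As a quick shape check (an illustrative addition, not part of the original
example), a mixer block preserves the [batch_size, num_patches, embedding_dim]
layout, which is what allows several blocks to be stacked back to back:
"""

demo_mixed = MLPMixerLayer(num_patches, embedding_dim, dropout_rate)(
    keras.ops.ones((2, num_patches, embedding_dim))
)
print(demo_mixed.shape)  # (2, 64, 256)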


"""
### Build, train, and evaluate the MLP-Mixer model

Note that training the model with the current settings on a V100 GPU
takes around 8 seconds per epoch.
"""

mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)

"""
The MLP-Mixer model tends to have far fewer parameters than convolutional
and Transformer-based models, which reduces training and serving
computational cost.

As mentioned in the [MLP-Mixer](https://arxiv.org/abs/2105.01601) paper,
when pre-trained on large datasets, or with modern regularization schemes,
MLP-Mixer attains competitive scores with state-of-the-art models.
You can obtain better results by increasing the embedding dimensions,
increasing the number of mixer blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.
"""

"""
## The FNet model

FNet uses a block similar to the Transformer block. However, FNet replaces the self-attention layer
in the Transformer block with a parameter-free 2D Fourier transformation layer:

1. One 1D Fourier Transform is applied along the patches.
2. One 1D Fourier Transform is applied along the channels.
"""

"""
### Implement the FNet module
"""


class FNetLayer(layers.Layer):
    def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply Fourier transformations.
        real_part = inputs
        im_part = keras.ops.zeros_like(inputs)
        x = keras.ops.fft2((real_part, im_part))[0]
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply the feedforward network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)
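

"""
The Fourier mixing step can be run in isolation (an illustrative addition, not
part of the original example): `keras.ops.fft2` takes a `(real, imaginary)`
pair of tensors and transforms the last two axes (patches and channels), and
the block keeps only the real component of the result:
"""

demo_x = keras.ops.ones((2, num_patches, embedding_dim))
demo_real, demo_imag = keras.ops.fft2((demo_x, keras.ops.zeros_like(demo_x)))
print(demo_real.shape)  # (2, 64, 256), with no trainable parameters involved.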


"""
### Build, train, and evaluate the FNet model

Note that training the model with the current settings on a V100 GPU
takes around 8 seconds per epoch.
"""

fnet_blocks = keras.Sequential(
    [FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.001
fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
history = run_experiment(fnet_classifier)

"""
As shown in the [FNet](https://arxiv.org/abs/2105.03824) paper,
better results can be achieved by increasing the embedding dimensions,
increasing the number of FNet blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.
FNet scales very efficiently to long inputs, runs much faster than attention-based
Transformer models, and produces competitive accuracy results.
"""

"""
## The gMLP model

The gMLP is an MLP architecture that features a Spatial Gating Unit (SGU).
The SGU enables cross-patch interactions along the spatial (patch) dimension by:

1. Transforming the input spatially by applying a linear projection across patches (independently for each channel).
2. Applying element-wise multiplication of the input and its spatial transformation.
"""

"""
### Implement the gMLP module
"""


class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimension.
        # Tensors u and v will each have shape [batch_size, num_patches, embedding_dim].
        u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = keras.ops.transpose(v, axes=(0, 2, 1))
        v_projected = self.spatial_projection(v_channels)
        v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected
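

"""
The spatial gating unit can also be inspected in isolation. This is an
illustrative sketch, not part of the original example; it calls the
`spatial_gating_unit` method directly on a dummy tensor of doubled width,
which is split into `u` and `v`, with `v` spatially projected before gating:
"""

demo_gmlp = gMLPLayer(num_patches, embedding_dim, dropout_rate)
demo_gated = demo_gmlp.spatial_gating_unit(
    keras.ops.ones((2, num_patches, embedding_dim * 2))
)
print(demo_gated.shape)  # The gating halves the channel dimension: (2, 64, 256).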


"""
### Build, train, and evaluate the gMLP model

Note that training the model with the current settings on a V100 GPU
takes around 9 seconds per epoch.
"""

gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)

"""
As shown in the [gMLP](https://arxiv.org/abs/2105.08050) paper,
better results can be achieved by increasing the embedding dimensions,
increasing the number of gMLP blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.
Note that the paper used advanced regularization strategies, such as MixUp and CutMix,
as well as AutoAugment.
"""