"""
Title: Image classification with Vision Transformer
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2021/01/18
Last modified: 2021/01/18
Description: Implementing the Vision Transformer (ViT) model for image classification.
Accelerator: GPU
"""

"""
## Introduction

This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
model by Alexey Dosovitskiy et al. for image classification,
and demonstrates it on the CIFAR-100 dataset.
The ViT model applies the Transformer architecture with self-attention to sequences of
image patches, without using convolution layers.

"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import keras
from keras import layers
from keras import ops

import numpy as np
import matplotlib.pyplot as plt

"""
## Prepare the data
"""

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")


"""
## Configure the hyperparameters
"""

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10  # For real training, use num_epochs=100. 10 is a test value
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [
    2048,
    1024,
]  # Size of the dense layers of the final classifier
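
"""
As a quick sanity check on the settings above (an illustrative aside, not part of the
original example): a 72 x 72 image split into non-overlapping 6 x 6 patches yields a
sequence of (72 // 6) ** 2 = 144 patches, each flattened to 6 * 6 * 3 = 108 values
before being projected to `projection_dim`.
"""

# Illustrative only: verify the patch-grid arithmetic implied by the hyperparameters.
patch_dim = patch_size * patch_size * 3  # 108 values per flattened RGB patch
print(f"Patches per image: {num_patches}")  # (72 // 6) ** 2 = 144
print(f"Elements per flattened patch: {patch_dim}")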


"""
## Use data augmentation
"""

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
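
"""
A brief illustration (not part of the original example) of what this pipeline does to a
batch: the random layers are only active during training, but `Normalization` and
`Resizing` still map the 32 x 32 CIFAR images to normalized `image_size` x `image_size`
inputs.
"""

# Illustrative only: the augmentation pipeline outputs (batch, 72, 72, 3) tensors.
sample_batch = data_augmentation(x_train[:4])
print(f"Augmented batch shape: {sample_batch.shape}")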


"""
## Implement multilayer perceptron (MLP)
"""


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


"""
## Implement patch creation as a layer
"""


class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        patches = ops.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


"""
Let's display patches for a sample image.
"""

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = ops.image.resize(
    ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
    plt.axis("off")

"""
## Implement the patch encoding layer

The `PatchEncoder` layer will linearly transform a patch by projecting it into a
vector of size `projection_dim`. In addition, it adds a learnable position
embedding to the projected vector.
"""


class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config
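
"""
As a quick check (an illustrative aside, not part of the original example), encoding the
sample patches from the visualization above keeps the sequence length of 144 but projects
each flattened 108-value patch to `projection_dim` features.
"""

# Illustrative only: (1, 144, 108) flattened patches become (1, 144, 64) encoded tokens.
encoded_sample = PatchEncoder(num_patches, projection_dim)(patches)
print(f"Encoded patches shape: {encoded_sample.shape}")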


"""
## Build the ViT model

The ViT model consists of multiple Transformer blocks,
which use the `layers.MultiHeadAttention` layer as a self-attention mechanism
applied to the sequence of patches. The Transformer blocks produce a
`[batch_size, num_patches, projection_dim]` tensor, which is processed by a
classifier head to produce the final class probabilities (the model outputs logits,
and the softmax is applied in the loss).

Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929),
which prepends a learnable embedding to the sequence of encoded patches to serve
as the image representation, all the outputs of the final Transformer block are
reshaped with `layers.Flatten()` and used as the image
representation input to the classifier head.
Note that the `layers.GlobalAveragePooling1D` layer
could also be used instead to aggregate the outputs of the Transformer block,
especially when the number of patches and the projection dimensions are large
(see the short illustration after the function below).
"""


def create_vit_classifier():
    inputs = keras.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, num_patches * projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
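
"""
The note above mentions `layers.GlobalAveragePooling1D` as an alternative to
`layers.Flatten()` for aggregating the Transformer output. As a minimal illustration
(not part of the original example), pooling a `[batch_size, num_patches, projection_dim]`
tensor collapses the patch axis, so the classifier head receives `projection_dim` features
instead of `num_patches * projection_dim`.
"""

# Illustrative only: compare the two aggregation strategies on a dummy Transformer output.
dummy_tokens = ops.ones((2, num_patches, projection_dim))
print(f"Flatten: {layers.Flatten()(dummy_tokens).shape}")  # (2, 144 * 64) = (2, 9216)
print(f"GlobalAveragePooling1D: {layers.GlobalAveragePooling1D()(dummy_tokens).shape}")  # (2, 64)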


"""
## Compile, train, and evaluate the model
"""


def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss")
plot_history("top-5-accuracy")


"""
After 100 epochs, the ViT model achieves around 55% accuracy and
82% top-5 accuracy on the test data. These are not competitive results on the CIFAR-100 dataset,
as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy.

Note that the state-of-the-art results reported in the
[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model using
the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality
without pre-training, you can try to train the model for more epochs, use a larger number of
Transformer layers, resize the input images, change the patch size, or increase the projection dimensions.
In addition, as mentioned in the paper, the quality of the model is affected not only by architecture choices,
but also by parameters such as the learning rate schedule, the optimizer, weight decay, etc.
(a minimal sketch of adding a learning-rate schedule follows below).
In practice, it's recommended to fine-tune a ViT model
that was pre-trained on a large, high-resolution dataset.
"""
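
"""
The note above lists the learning rate schedule among the training parameters that matter.
As a minimal, illustrative sketch (not part of the original example), one way to add a
schedule is to drive the `AdamW` optimizer used in `run_experiment` with a cosine decay;
the `decay_steps` value below is only an assumption derived from the settings above.
"""

# Illustrative only: an AdamW optimizer driven by a cosine learning-rate schedule.
steps_per_epoch = int(x_train.shape[0] * 0.9) // batch_size  # fit() holds out 10% for validation
cosine_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=learning_rate,
    decay_steps=num_epochs * steps_per_epoch,
)
scheduled_optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_schedule, weight_decay=weight_decay
)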