"""
Title: Image classification with EANet (External Attention Transformer)
Author: [ZhiYong Chang](https://github.com/czy00000)
Date created: 2021/10/19
Last modified: 2023/07/18
Description: Image classification with a Transformer that leverages external attention.
Accelerator: GPU
Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)
"""

"""
## Introduction

This example implements the [EANet](https://arxiv.org/abs/2105.02358)
model for image classification, and demonstrates it on the CIFAR-100 dataset.
EANet introduces a novel attention mechanism
named ***external attention***, based on two external, small, learnable, and
shared memories, which can be implemented easily using two cascaded
linear layers and two normalization layers. It conveniently replaces the self-attention
used in existing architectures. External attention has linear complexity, as it only
implicitly considers the correlations between all samples.
"""

"""
## Setup
"""

import keras
from keras import layers
from keras import ops

import matplotlib.pyplot as plt


"""
## Prepare the data
"""
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

"""
## Configure the hyperparameters
"""

weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches per image.
embedding_dim = 64  # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the transformer layer.

print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2}")
print(f"Patches per image: {num_patches}")


"""
## Use data augmentation
"""

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomContrast(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

"""
## Implement the patch extraction and encoding layer
"""


class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        # Split each image into non-overlapping patches and flatten each patch.
        B, C = ops.shape(x)[0], ops.shape(x)[-1]
        x = ops.image.extract_patches(x, self.patch_size)
        x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
        return x


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        # Linearly project each patch and add a learnable positional embedding.
        pos = ops.arange(start=0, stop=self.num_patch, step=1)
        return self.proj(patch) + self.pos_embed(pos)

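
"""
As an optional, illustrative sanity check, a small batch of CIFAR-100 images turns into
`(batch, num_patches, patch_size * patch_size * 3)` flattened patches, which the embedding
layer then projects to `embedding_dim`. The variable names below are just for the demo.
"""

demo_patches = PatchExtract(patch_size)(x_train[:4].astype("float32"))
demo_embeddings = PatchEmbedding(num_patches, embedding_dim)(demo_patches)
print(f"Patches shape: {demo_patches.shape}")  # (4, 256, 12)
print(f"Embeddings shape: {demo_embeddings.shape}")  # (4, 256, 64)
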
"""
## Implement the external attention block
"""


def external_attention(
    x,
    dim,
    num_heads,
    dim_coefficient=4,
    attention_dropout=0,
    projection_dropout=0,
):
    _, num_patch, channel = x.shape
    assert dim % num_heads == 0
    num_heads = num_heads * dim_coefficient

    x = layers.Dense(dim * dim_coefficient)(x)
    # Create tensor [batch_size, num_patches, num_heads, dim * dim_coefficient // num_heads].
    x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    # A linear layer M_k.
    attn = layers.Dense(dim // dim_coefficient)(x)
    # Normalize the attention map.
    attn = layers.Softmax(axis=2)(attn)
    # Double normalization.
    attn = layers.Lambda(
        lambda attn: ops.divide(
            attn,
            ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
        )
    )(attn)
    attn = layers.Dropout(attention_dropout)(attn)
    # A linear layer M_v.
    x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
    # A linear layer to project back to the original embedding dimension.
    x = layers.Dense(dim)(x)
    x = layers.Dropout(projection_dropout)(x)
    return x

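
"""
A quick, illustrative shape check (the call below is just a demo): external attention
preserves the `(batch, num_patches, embedding_dim)` shape of its input.
"""

demo_out = external_attention(
    ops.ones((1, num_patches, embedding_dim)),
    embedding_dim,
    num_heads,
    dim_coefficient,
)
print(f"External attention output shape: {demo_out.shape}")  # (1, 256, 64)
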
"""
## Implement the MLP block
"""


def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
    x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
    x = layers.Dropout(drop_rate)(x)
    x = layers.Dense(embedding_dim)(x)
    x = layers.Dropout(drop_rate)(x)
    return x


"""
## Implement the Transformer block
"""

def transformer_encoder(
    x,
    embedding_dim,
    mlp_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
    attention_type="external_attention",
):
    residual_1 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    if attention_type == "external_attention":
        x = external_attention(
            x,
            embedding_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
        )
    elif attention_type == "self_attention":
        x = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim,
            dropout=attention_dropout,
        )(x, x)
    x = layers.add([x, residual_1])
    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    x = mlp(x, embedding_dim, mlp_dim)
    x = layers.add([x, residual_2])
    return x


"""
## Implement the EANet model
"""

"""
The EANet model leverages external attention.
The computational complexity of traditional self-attention is `O(d * N ** 2)`,
where `d` is the embedding size and `N` is the number of patches.
The authors found that most pixels are closely related to just a few other
pixels, so an `N`-to-`N` attention matrix may be redundant.
As an alternative, they propose an external attention module
whose computational complexity is `O(d * S * N)`.
As `d` and `S` are hyperparameters,
the proposed algorithm is linear in the number of pixels. In effect, this is equivalent
to a drop-patch operation, because much of the information contained in an image patch
is redundant and unimportant.
"""


def get_model(attention_type="external_attention"):
    inputs = layers.Input(shape=input_shape)
    # Augment the images.
    x = data_augmentation(inputs)
    # Extract patches.
    x = PatchExtract(patch_size)(x)
    # Create the patch embeddings.
    x = PatchEmbedding(num_patches, embedding_dim)(x)
    # Stack the Transformer blocks.
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(
            x,
            embedding_dim,
            mlp_dim,
            num_heads,
            dim_coefficient,
            attention_dropout,
            projection_dropout,
            attention_type,
        )

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

"""
## Train on CIFAR-100
"""


model = get_model(attention_type="external_attention")

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)

"""
### Let's visualize the training progress of the model.
"""

plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

"""
### Let's display the final results of the test on CIFAR-100.
"""

loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

"""
EANet simply replaces self-attention in ViT with external attention.
The traditional ViT achieved ~73% test top-5 accuracy and ~41% top-1 accuracy after
training for 50 epochs, but with 0.6M parameters. Under the same experimental environment
and the same hyperparameters, the EANet model we just trained has only 0.3M parameters,
and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates
the effectiveness of external attention.

We only show the training process of EANet here; you can train ViT under the same
experimental conditions and compare the test results.
"""