"""
Title: Focal Modulation: A replacement for Self-Attention
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
Date created: 2023/01/25
Last modified: 2026/01/27
Description: Image classification with Focal Modulation Networks.
Accelerator: GPU
Converted to Keras 3 by: [LakshmiKalaKadali](https://github.com/LakshmiKalaKadali)
"""

"""
## Introduction

This tutorial aims to provide a comprehensive guide to the implementation of
Focal Modulation Networks, as presented in
[Yang et al.](https://arxiv.org/abs/2203.11926).

This tutorial will provide a formal, minimalistic approach to implementing Focal
Modulation Networks and explore its potential applications in the field of Deep Learning.

**Problem statement**

The Transformer architecture ([Vaswani et al.](https://arxiv.org/abs/1706.03762)),
which has become the de facto standard in most Natural Language Processing tasks, has
also been applied to the field of computer vision, e.g. Vision
Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929v2)).

> In Transformers, the self-attention (SA) is arguably the key to its success which
enables input-dependent global interactions, in contrast to convolution operation which
constrains interactions in a local region with a shared kernel.

The **Attention** module is mathematically written as shown in **Equation 1**.

| ![Attention Equation](https://i.imgur.com/thdHvQx.png) |
| :--: |
| Equation 1: The mathematical equation of attention (Source: Aritra and Ritwik) |

Where:

- `Q` is the query
- `K` is the key
- `V` is the value
- `d_k` is the dimension of the key

With **self-attention**, the query, key, and value are all sourced from the input
sequence. Let us rewrite the attention equation for self-attention as shown in **Equation
2**.

| ![Self-Attention Equation](https://i.imgur.com/OFsmVdP.png) |
| :--: |
| Equation 2: The mathematical equation of self-attention (Source: Aritra and Ritwik) |

Upon looking at the equation of self-attention, we see that its cost is quadratic in the
number of tokens. Therefore, as the number of tokens increases, so does the computation
time (and cost). To mitigate this problem and make Transformers more interpretable,
Yang et al. have tried to replace the Self-Attention module with better components.
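
To make the quadratic cost concrete, here is a minimal sketch of scaled dot-product
self-attention written with `keras.ops`. It is purely illustrative (the tensor names and
shapes below are made up for this example and are not part of the model we build later);
note that the attention matrix has shape `(batch, tokens, tokens)`.

```python
import keras
from keras import ops

batch, tokens, dim = 1, 64, 32
x = keras.random.normal((batch, tokens, dim))

# In self-attention, the query, key, and value all come from the same input.
q = k = v = x
scores = ops.matmul(q, ops.transpose(k, (0, 2, 1))) / (dim**0.5)  # (batch, tokens, tokens)
weights = ops.softmax(scores, axis=-1)  # quadratic in the number of tokens
y = ops.matmul(weights, v)  # (batch, tokens, dim)
```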

**The Solution**

Yang et al. introduce the Focal Modulation layer to serve as a
seamless replacement for the Self-Attention Layer. The layer boasts high
interpretability, making it a valuable tool for Deep Learning practitioners.

In this tutorial, we will delve into the practical application of this layer by training
the entire model on the CIFAR-10 dataset and visually interpreting the layer's
performance.

Note: We try to align our implementation with the
[official implementation](https://github.com/microsoft/FocalNet).
"""

"""
## Setup and Imports

Keras 3 allows this model to run on JAX, PyTorch, or TensorFlow. We use `keras.ops` for
all mathematical operations to ensure the code remains backend-agnostic.
"""

import os

# Set the backend before importing keras.
os.environ["KERAS_BACKEND"] = "tensorflow"  # Or "jax" or "torch"

import numpy as np
import keras
from keras import layers
from keras import ops
from matplotlib import pyplot as plt
from random import randint

# Set seed for reproducibility using the Keras 3 utility.
keras.utils.set_random_seed(42)

"""
## Global Configuration

We do not have any strong rationale behind choosing these hyperparameters. Please feel
free to change the configuration and train the model.
"""

# --- GLOBAL CONFIGURATION ---
TRAIN_SLICE = 40000
BATCH_SIZE = 128  # 1024
INPUT_SHAPE = (32, 32, 3)
IMAGE_SIZE = 48
NUM_CLASSES = 10

LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS = 20

"""
## Data Loading with PyDataset

Keras 3 introduces `PyDataset` as a standardized way to handle data.
It works identically across all backends and avoids the "Symbolic Tensor" issues often
found when using `tf.data` with JAX or PyTorch.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
    (x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
)


class FocalDataset(keras.utils.PyDataset):
    def __init__(self, x_data, y_data, batch_size, shuffle=False, **kwargs):
        super().__init__(**kwargs)
        self.x_data = x_data
        self.y_data = y_data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(x_data))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        return int(np.ceil(len(self.x_data) / self.batch_size))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = min((idx + 1) * self.batch_size, len(self.x_data))
        batch_indices = self.indices[start:end]

        x_batch = self.x_data[batch_indices]
        y_batch = self.y_data[batch_indices]

        # Convert to backend-native tensors.
        x_batch = ops.convert_to_tensor(x_batch, dtype="float32")
        y_batch = ops.convert_to_tensor(y_batch, dtype="int32")

        return x_batch, y_batch

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)


train_ds = FocalDataset(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)
val_ds = FocalDataset(x_val, y_val, batch_size=BATCH_SIZE, shuffle=False)
test_ds = FocalDataset(x_test, y_test, batch_size=BATCH_SIZE, shuffle=False)
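
"""
As a quick, optional sanity check, we can index the `PyDataset` directly to inspect one
batch. The variable names below are only for illustration and are not used later.
"""

sample_images, sample_labels = train_ds[0]
print("Batch of images:", ops.shape(sample_images))
print("Batch of labels:", ops.shape(sample_labels))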

"""
## Architecture

We pause here to take a quick look at the architecture of the Focal Modulation Network.
**Figure 1** shows how every individual layer is compiled into a single model. This gives
us a bird's eye view of the entire architecture.

| ![Diagram of the model](https://i.imgur.com/v5HYV5R.png) |
| :--: |
| Figure 1: A diagram of the Focal Modulation model (Source: Aritra and Ritwik) |

We dive deep into each of these layers in the following sections. This is the order we
will follow:

- Patch Embedding Layer
- Focal Modulation Block
  - Multi-Layer Perceptron
  - Focal Modulation Layer
    - Hierarchical Contextualization
    - Gated Aggregation
  - Building Focal Modulation Block
- Building the Basic Layer

To better understand the architecture in a format we are well versed in, let us see how
the Focal Modulation Network would look when drawn like a Transformer architecture.

**Figure 2** shows the encoder layer of a traditional Transformer architecture where Self
Attention is replaced with the Focal Modulation layer.

The <font color="blue">blue</font> blocks represent the Focal Modulation block. A stack
of these blocks builds a single Basic Layer. The <font color="green">green</font> blocks
represent the Focal Modulation layer.

| ![The Entire Architecture](https://i.imgur.com/PduYD6m.png) |
| :--: |
| Figure 2: The Entire Architecture (Source: Aritra and Ritwik) |
"""

"""
## Patch Embedding Layer

The patch embedding layer is used to patchify the input images and project them into a
latent space. This layer is also used as the down-sampling layer in the architecture.
"""


class PatchEmbed(layers.Layer):
    """Image patch embedding layer, also acts as the down-sampling layer.

    Args:
        image_size (Tuple[int]): Input image resolution.
        patch_size (Tuple[int]): Patch spatial resolution.
        embed_dim (int): Embedding dimension.
    """

    def __init__(
        self, image_size=(224, 224), patch_size=(4, 4), embed_dim=96, **kwargs
    ):
        super().__init__(**kwargs)
        self.patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]
        self.proj = layers.Conv2D(
            filters=embed_dim, kernel_size=patch_size, strides=patch_size
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
        self.norm = layers.LayerNormalization(epsilon=1e-7)

    def call(self, x):
        """Patchifies the image and converts it into tokens.

        Args:
            x: Tensor of shape (B, H, W, C)

        Returns:
            A tuple of the processed tensor, height of the projected
            feature map, width of the projected feature map, number
            of channels of the projected feature map.
        """
        x = self.proj(x)
        shape = ops.shape(x)
        height, width, channels = shape[1], shape[2], shape[3]
        x = self.norm(self.flatten(x))
        return x, height, width, channels
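
"""
As an optional shape check (not part of the training pipeline), a batch of 48x48x3 images
split into 4x4 patches should produce 12 x 12 = 144 tokens per image. The `dummy_*` names
below are only used for this check.
"""

dummy_images = keras.random.normal((1, IMAGE_SIZE, IMAGE_SIZE, 3))
dummy_tokens, dummy_height, dummy_width, dummy_channels = PatchEmbed(
    image_size=(IMAGE_SIZE, IMAGE_SIZE), patch_size=(4, 4), embed_dim=64
)(dummy_images)
print(
    "Tokens:", ops.shape(dummy_tokens),
    "| Feature map:", int(dummy_height), "x", int(dummy_width),
)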


"""
## Focal Modulation block

A Focal Modulation block can be considered as a single Transformer Block with the Self
Attention (SA) module being replaced with the Focal Modulation module, as we saw in
**Figure 2**.

Let us recall what a Focal Modulation block is supposed to look like with the aid of
**Figure 3**.

| ![Focal Modulation Block](https://i.imgur.com/bPYTSiB.png) |
| :--: |
| Figure 3: The isolated view of the Focal Modulation Block (Source: Aritra and Ritwik) |

The Focal Modulation Block consists of:
- Multilayer Perceptron
- Focal Modulation layer
"""

"""
### Multilayer Perceptron
"""


def MLP(in_features, hidden_features=None, out_features=None, mlp_drop_rate=0.0):
    hidden_features = hidden_features or in_features
    out_features = out_features or in_features
    return keras.Sequential(
        [
            layers.Dense(units=hidden_features, activation="gelu"),
            layers.Dense(units=out_features),
            layers.Dropout(rate=mlp_drop_rate),
        ]
    )


"""
### Focal Modulation layer

In a typical Transformer architecture, for each visual token (**query**) `x_i in R^C` in
an input feature map `X in R^{HxWxC}`, a **generic encoding process** produces a feature
representation `y_i in R^C`.

The encoding process consists of **interaction** (with its surroundings, e.g. a dot
product) and **aggregation** (over the contexts, e.g. a weighted mean).

We will talk about two types of encoding here:
- Interaction and then Aggregation in **Self-Attention**
- Aggregation and then Interaction in **Focal Modulation**

**Self-Attention**

| ![Self-Attention Expression](https://i.imgur.com/heBYp0F.png) |
| :--: |
| **Figure 4**: Self-Attention module. (Source: Aritra and Ritwik) |

| ![Aggregation and Interaction for Self-Attention](https://i.imgur.com/j1k8Xmy.png) |
| :--: |
| **Equation 3:** Aggregation and Interaction in Self-Attention (Source: Aritra and Ritwik) |

As shown in **Figure 4**, the query and the key interact (in the interaction step) with
each other to output the attention scores. The weighted aggregation of the value comes
next, known as the aggregation step.

**Focal Modulation**

| ![Focal Modulation module](https://i.imgur.com/tmbLgQl.png) |
| :--: |
| **Figure 5**: Focal Modulation module. (Source: Aritra and Ritwik) |

| ![Aggregation and Interaction in Focal Modulation](https://i.imgur.com/gsvJfWp.png) |
| :--: |
| **Equation 4:** Aggregation and Interaction in Focal Modulation (Source: Aritra and Ritwik) |

**Figure 5** depicts the Focal Modulation layer. `q()` is the query projection
function. It is a **linear layer** that projects the query into a latent space. `m()` is
the context aggregation function. Unlike self-attention, the
aggregation step takes place in focal modulation before the interaction step.
"""

"""
While `q()` is pretty straightforward to understand, the context aggregation function
`m()` is more complex. Therefore, this section will focus on `m()`.

| ![Context Aggregation](https://i.imgur.com/uqIRXI7.png)|
| :--: |
| **Figure 6**: Context Aggregation function `m()`. (Source: Aritra and Ritwik) |

The context aggregation function `m()` consists of two parts as shown in **Figure 6**:
- Hierarchical Contextualization
- Gated Aggregation
"""

"""
#### Hierarchical Contextualization

| ![Hierarchical Contextualization](https://i.imgur.com/q875c83.png)|
| :--: |
| **Figure 7**: Hierarchical Contextualization (Source: Aritra and Ritwik) |

In **Figure 7**, we see that the input is first projected linearly. This linear projection
produces `Z^0`, which can be expressed as follows:

| ![Linear projection of z_not](https://i.imgur.com/pd0Z2Of.png) |
| :--: |
| Equation 5: Linear projection of `Z^0` (Source: Aritra and Ritwik) |

`Z^0` is then passed on to a series of Depth-Wise Conv (DWConv) and
[GeLU](https://keras.io/api/layers/activations/#gelu-function) layers. The
authors term each block of DWConv and GeLU as levels denoted by `l`. In **Figure 6** we
have two levels. Mathematically this is represented as:

| ![Levels of modulation](https://i.imgur.com/ijGD1Df.png) |
| :--: |
| Equation 6: Levels of the modulation layer (Source: Aritra and Ritwik) |

where `l in {1, ... , L}`

The final feature map goes through a Global Average Pooling Layer. This can be expressed
as follows:

| ![Avg Pool](https://i.imgur.com/MQzQhbo.png) |
| :--: |
| Equation 7: Average Pooling of the final feature (Source: Aritra and Ritwik)|
"""

"""
#### Gated Aggregation

| ![Gated Aggregation](https://i.imgur.com/LwrdDKo.png)|
| :--: |
| **Figure 8**: Gated Aggregation (Source: Aritra and Ritwik) |

Now that we have `L+1` intermediate feature maps by virtue of the Hierarchical
Contextualization step, we need a gating mechanism that lets some features pass and
prohibits others. This can be implemented with the attention module.
Later in the tutorial, we will visualize these gates to better understand their
usefulness.

First, we build the weights for aggregation. Here we apply a **linear layer** on the input
feature map that projects it into `L+1` dimensions.

| ![Gates](https://i.imgur.com/1CgEo1G.png) |
| :--: |
| Equation 8: Gates (Source: Aritra and Ritwik) |

Next we perform the weighted aggregation over the contexts.

| ![z out](https://i.imgur.com/mpJ712R.png) |
| :--: |
| Equation 9: Final feature map (Source: Aritra and Ritwik) |

To enable communication across different channels, we use another linear layer `h()`
to obtain the modulator.

| ![Modulator](https://i.imgur.com/0EpT3Ti.png) |
| :--: |
| Equation 10: Modulator (Source: Aritra and Ritwik) |

To sum up the Focal Modulation layer, we have:

| ![Focal Modulation Layer](https://i.imgur.com/1QIhvYA.png) |
| :--: |
| Equation 11: Focal Modulation Layer (Source: Aritra and Ritwik) |
"""


class FocalModulationLayer(layers.Layer):
    """The Focal Modulation layer includes query projection & context aggregation.

    Args:
        dim (int): Projection dimension.
        focal_window (int): Window size for focal modulation.
        focal_level (int): The current focal level.
        focal_factor (int): Factor of focal modulation.
        proj_drop_rate (float): Rate of dropout.
    """

    def __init__(
        self,
        dim,
        focal_window,
        focal_level,
        focal_factor=2,
        proj_drop_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim, self.focal_level = dim, focal_level
        self.initial_proj = layers.Dense(units=(2 * dim) + (focal_level + 1))
        self.focal_layers = [
            keras.Sequential(
                [
                    layers.ZeroPadding2D(
                        padding=((focal_factor * i + focal_window) // 2)
                    ),
                    layers.Conv2D(
                        filters=dim,
                        kernel_size=(focal_factor * i + focal_window),
                        activation="gelu",
                        groups=dim,
                        use_bias=False,
                    ),
                ]
            )
            for i in range(focal_level)
        ]
        self.gap = layers.GlobalAveragePooling2D(keepdims=True)
        self.mod_proj = layers.Conv2D(filters=dim, kernel_size=1)
        self.proj = layers.Dense(units=dim)
        self.proj_drop = layers.Dropout(proj_drop_rate)

    def call(self, x, training=None):
        """Forward pass of the layer.

        Args:
            x: Tensor of shape (B, H, W, C)
        """
        x_proj = self.initial_proj(x)
        query, context, gates = ops.split(x_proj, [self.dim, 2 * self.dim], axis=-1)

        # Apply Softmax for numerical stability
        gates = ops.softmax(gates, axis=-1)
        self.gates = gates

        context = self.focal_layers[0](context)
        context_all = context * gates[..., 0:1]
        for i in range(1, self.focal_level):
            context = self.focal_layers[i](context)
            context_all = context_all + (context * gates[..., i : i + 1])

        context_global = ops.gelu(self.gap(context))
        context_all = context_all + (context_global * gates[..., self.focal_level :])

        self.modulator = self.mod_proj(context_all)
        x_out = query * self.modulator
        return self.proj_drop(self.proj(x_out), training=training)
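
"""
As a quick optional check, passing a dummy `(B, H, W, C)` feature map through a
`FocalModulationLayer` should preserve the spatial resolution and channel count. The
`demo_*` names below are only used for this check.
"""

demo_layer = FocalModulationLayer(dim=64, focal_window=3, focal_level=2)
demo_output = demo_layer(keras.random.normal((1, 12, 12, 64)))
print("Focal Modulation output shape:", ops.shape(demo_output))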


"""
### The Focal Modulation block

Finally, we have all the components we need to build the Focal Modulation block. Here we
take the MLP and Focal Modulation layer together and build the Focal Modulation block.
"""


class FocalModulationBlock(layers.Layer):
    """Combine FFN and Focal Modulation Layer.

    Args:
        dim (int): Number of input channels.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float): Dropout rate.
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at first focal level
    """

    def __init__(
        self, dim, mlp_ratio=4.0, drop=0.0, focal_level=1, focal_window=3, **kwargs
    ):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.modulation = FocalModulationLayer(
            dim, focal_window, focal_level, proj_drop_rate=drop
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = MLP(dim, int(dim * mlp_ratio), mlp_drop_rate=drop)

    def call(self, x, height=None, width=None, channels=None, training=None):
        """Processes the input tensor through the focal modulation block.

        Args:
            x : Inputs of the shape (B, L, C)
            height (int): The height of the feature map
            width (int): The width of the feature map
            channels (int): The number of channels of the feature map

        Returns:
            The processed tensor.
        """
        res = x
        x = ops.reshape(x, (-1, height, width, channels))
        x = self.modulation(x, training=training)
        x = ops.reshape(x, (-1, height * width, channels))
        x = res + x
        return x + self.mlp(self.norm2(x), training=training)


"""
## The Basic Layer

The basic layer consists of a collection of Focal Modulation blocks. This is
illustrated in **Figure 9**.

| ![Basic Layer](https://i.imgur.com/UcZV0K6.png) |
| :--: |
| **Figure 9**: Basic Layer, a collection of focal modulation blocks. (Source: Aritra and Ritwik) |

Notice how in **Figure 9** there is more than one Focal Modulation block, denoted by `Nx`.
This shows how the Basic Layer is a collection of Focal Modulation blocks.
"""


class BasicLayer(layers.Layer):
    """Collection of Focal Modulation Blocks.

    Args:
        dim (int): Dimensions of the model.
        out_dim (int): Dimension used by the Patch Embedding Layer.
        input_res (Tuple[int]): Input image resolution.
        depth (int): The number of Focal Modulation Blocks.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float): Dropout rate.
        downsample (keras.layers.Layer): Downsampling layer at the end of the layer.
        focal_level (int): The current focal level.
        focal_window (int): Focal window used.
    """

    def __init__(
        self,
        dim,
        out_dim,
        input_res,
        depth,
        mlp_ratio=4.0,
        drop=0.0,
        downsample=None,
        focal_level=1,
        focal_window=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.blocks = [
            FocalModulationBlock(dim, mlp_ratio, drop, focal_level, focal_window)
            for _ in range(depth)
        ]
        self.downsample = (
            downsample(image_size=input_res, patch_size=(2, 2), embed_dim=out_dim)
            if downsample
            else None
        )

    def call(self, x, height=None, width=None, channels=None, training=None):
        """Forward pass of the layer.

        Args:
            x : Tensor of shape (B, L, C)
            height (int): Height of feature map
            width (int): Width of feature map
            channels (int): Embed Dim of feature map

        Returns:
            A tuple of the processed tensor, changed height, width, and
            dim of the tensor.
        """
        for block in self.blocks:
            x = block(
                x, height=height, width=width, channels=channels, training=training
            )
        if self.downsample:
            x = ops.reshape(x, (-1, height, width, channels))
            x, height, width, channels = self.downsample(x)
        return x, height, width, channels


"""
## The Focal Modulation Network model

This is the model that ties everything together.
It consists of a collection of Basic Layers with a classification head.
For a recap of how this is structured refer to **Figure 1**.
"""


class FocalModulationNetwork(keras.Model):
    """The Focal Modulation Network.

    Parameters:
        image_size (Tuple[int]): Spatial size of images used.
        patch_size (Tuple[int]): Patch size of each patch.
        num_classes (int): Number of classes used for classification.
        embed_dim (int): Patch embedding dimension.
        depths (List[int]): Depth of each Focal Transformer block.
    """

    def __init__(
        self,
        image_size=(48, 48),
        patch_size=(4, 4),
        num_classes=10,
        embed_dim=64,
        depths=[2, 3, 2],
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Preprocessing integrated into the model for backend-agnostic behavior.
        self.rescaling = layers.Rescaling(1.0 / 255.0)
        self.resizing_larger = layers.Resizing(image_size[0] + 10, image_size[1] + 10)
        self.random_crop = layers.RandomCrop(image_size[0], image_size[1])
        self.resizing_target = layers.Resizing(image_size[0], image_size[1])
        self.random_flip = layers.RandomFlip("horizontal")

        self.patch_embed = PatchEmbed(image_size, patch_size, embed_dim)
        self.basic_layers = []
        for i in range(len(depths)):
            d = embed_dim * (2**i)
            self.basic_layers.append(
                BasicLayer(
                    dim=d,
                    out_dim=d * 2 if i < len(depths) - 1 else None,
                    input_res=(image_size[0] // (2**i), image_size[1] // (2**i)),
                    depth=depths[i],
                    downsample=PatchEmbed if i < len(depths) - 1 else None,
                )
            )
        self.norm = layers.LayerNormalization(epsilon=1e-7)
        self.avgpool = layers.GlobalAveragePooling1D()
        self.head = layers.Dense(num_classes, activation="softmax")

    def call(self, x, training=None):
        """Forward pass of the model.

        Args:
            x: Tensor of shape (B, H, W, C)

        Returns:
            The class probabilities.
        """
        x = self.rescaling(x)
        if training:
            x = self.resizing_larger(x)
            x = self.random_crop(x)
            x = self.random_flip(x)
        else:
            x = self.resizing_target(x)

        x, h, w, c = self.patch_embed(x)
        for layer in self.basic_layers:
            x, h, w, c = layer(x, height=h, width=w, channels=c, training=training)
        return self.head(self.avgpool(self.norm(x)))


"""
## Train the model

Now with all the components in place and the architecture actually built, we are ready to
put it to good use.

In this section, we train our Focal Modulation model on the CIFAR-10 dataset.
"""

"""
### Visualization Callback

A key feature of the Focal Modulation Network is explicit input-dependency. This means
the modulator is calculated by looking at the local features around the target location,
so it depends on the input. In very simple terms, this makes interpretation easy: we can
simply lay the gating values and the original image next to each other to see how the
gating mechanism works.

The authors of the paper visualize the gates and the modulator in order to focus on the
interpretability of the Focal Modulation layer. Below is a visualization
callback that shows the gates and modulator of a specific layer in the model while the
model trains.

We will notice later that as the model trains, the visualizations get better.

The gates appear to selectively permit certain aspects of the input image to pass
through, while gently disregarding others, ultimately leading to improved classification
accuracy.
"""


def display_grid(test_images, gates, modulator):
    """Displays the image with the gates and modulator overlaid.

    Args:
        test_images: A batch of test images.
        gates: The gates of the Focal Modulation Layer.
        modulator: The modulator of the Focal Modulation Layer.
    """
    test_images_np = ops.convert_to_numpy(test_images) / 255.0
    gates_np = ops.convert_to_numpy(gates)
    mod_np = ops.convert_to_numpy(ops.norm(modulator, axis=-1))

    num_gates = gates_np.shape[-1]
    idx = randint(0, test_images_np.shape[0] - 1)
    fig, ax = plt.subplots(1, num_gates + 2, figsize=((num_gates + 2) * 4, 4))

    ax[0].imshow(test_images_np[idx])
    ax[0].set_title("Original")
    ax[0].axis("off")
    for i in range(num_gates):
        ax[i + 1].imshow(test_images_np[idx])
        ax[i + 1].imshow(gates_np[idx, ..., i], cmap="inferno", alpha=0.6)
        ax[i + 1].set_title(f"Gate {i+1}")
        ax[i + 1].axis("off")

    ax[-1].imshow(test_images_np[idx])
    ax[-1].imshow(mod_np[idx], cmap="inferno", alpha=0.6)
    ax[-1].set_title("Modulator")
    ax[-1].axis("off")
    plt.show()
    plt.close()


"""
### TrainMonitor
"""

# Fetch a test batch for the callback.
test_batch_images, _ = test_ds[0]


class TrainMonitor(keras.callbacks.Callback):
    def __init__(self, epoch_interval=10):
        super().__init__()
        self.epoch_interval = epoch_interval
        self.upsampler = layers.UpSampling2D(size=(4, 4), interpolation="bilinear")

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.epoch_interval == 0:
            _ = self.model(test_batch_images, training=False)
            layer = self.model.basic_layers[1].blocks[-1].modulation
            display_grid(
                test_batch_images,
                self.upsampler(layer.gates),
                self.upsampler(layer.modulator),
            )


"""
### Learning Rate scheduler
"""


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, lr_base, total_steps, warmup_steps):
        super().__init__()
        self.lr_base, self.total_steps, self.warmup_steps = (
            lr_base,
            total_steps,
            warmup_steps,
        )

    def __call__(self, step):
        step = ops.cast(step, "float32")
        cos_lr = (
            0.5
            * self.lr_base
            * (
                1
                + ops.cos(
                    np.pi
                    * (step - self.warmup_steps)
                    / (self.total_steps - self.warmup_steps)
                )
            )
        )
        warmup_lr = (self.lr_base / self.warmup_steps) * step
        return ops.where(
            step < self.warmup_steps,
            warmup_lr,
            ops.where(step > self.total_steps, 0.0, cos_lr),
        )


total_steps = (len(x_train) // BATCH_SIZE) * EPOCHS
scheduled_lrs = WarmUpCosine(LEARNING_RATE, total_steps, int(total_steps * 0.15))
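
"""
As an optional sanity check, we can evaluate the schedule at a few steps: it should warm
up linearly, peak at `LEARNING_RATE` at the end of warmup, and then decay to zero along a
cosine curve. The `demo_step` name below is only used here.
"""

for demo_step in [0, int(total_steps * 0.15), total_steps // 2, total_steps]:
    print(f"step {demo_step}: lr = {float(scheduled_lrs(demo_step)):.2e}")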

"""
### Initialize, compile and train the model
"""

model = FocalModulationNetwork(image_size=(IMAGE_SIZE, IMAGE_SIZE))
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY, clipnorm=1.0
    ),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[TrainMonitor(epoch_interval=5)],
)

"""
## Plot loss and accuracy
"""
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history["loss"], label="Train Loss")
plt.plot(history.history["val_loss"], label="Val Loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history["accuracy"], label="Train Acc")
plt.plot(history.history["val_accuracy"], label="Val Acc")
plt.legend()
plt.show()

"""
## Test visualizations

Let's test our model on some test images and see what the gates look like.
"""
test_images, test_labels = next(iter(test_ds))

_ = model(test_images, training=False)

target_layer = model.basic_layers[1].blocks[-1].modulation
gates = target_layer.gates
modulator = target_layer.modulator

upsampler = layers.UpSampling2D(size=(4, 4), interpolation="bilinear")
gates_upsampled = upsampler(gates)
modulator_upsampled = upsampler(modulator)

for row in range(5):
    display_grid(
        test_images=test_images,
        gates=gates_upsampled,
        modulator=modulator_upsampled,
    )

"""
## Conclusion

The proposed Focal Modulation Network architecture is a mechanism that allows different
parts of an image to interact with each other in a way that depends on the image itself.
It works by first gathering different levels of context information around each part of
the image (the "query token"), then using a gate to decide which context information is
most relevant, and finally combining the chosen information in a simple but effective
way.

This is meant as a replacement for the Self-Attention mechanism from the Transformer
architecture. The key feature that makes this research notable is not the conception of
attention-less networks, but rather the introduction of an equally powerful architecture
that is interpretable.

The authors also mention that they created a series of Focal Modulation Networks
(FocalNets) that significantly outperform their Self-Attention counterparts while using a
fraction of the parameters and pretraining data.

The FocalNets architecture has the potential to deliver impressive results and offers a
simple implementation. Its promising performance and ease of use make it an attractive
alternative to Self-Attention for researchers to explore in their own projects. It could
potentially become widely adopted by the Deep Learning community in the near future.

## Acknowledgement

We would like to thank [PyImageSearch](https://pyimagesearch.com/) for providing us with a
Colab Pro account, [JarvisLabs.ai](https://cloud.jarvislabs.ai/) for GPU credits,
and also Microsoft Research for providing an
[official implementation](https://github.com/microsoft/FocalNet) of their paper.
We would also like to extend our gratitude to the first author of the
paper, [Jianwei Yang](https://twitter.com/jw2yang4ai), who reviewed this tutorial
extensively.
"""

"""
## Relevant Chapters from Deep Learning with Python
- [Chapter 8: Image classification](https://deeplearningwithpython.io/chapters/chapter08_image-classification)
"""