"""
2
Title: Focal Modulation: A replacement for Self-Attention
3
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
4
Date created: 2023/01/25
5
Last modified: 2023/02/15
6
Description: Image classification with Focal Modulation Networks.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
This tutorial aims to provide a comprehensive guide to the implementation of
14
Focal Modulation Networks, as presented in
15
[Yang et al.](https://arxiv.org/abs/2203.11926).
16
17
This tutorial will provide a formal, minimalistic approach to implementing Focal
18
Modulation Networks and explore its potential applications in the field of Deep Learning.
19
20
**Problem statement**
21
22
The Transformer architecture ([Vaswani et al.](https://arxiv.org/abs/1706.03762)),
23
which has become the de facto standard in most Natural Language Processing tasks, has
24
also been applied to the field of computer vision, e.g. Vision
25
Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929v2)).
26
27
> In Transformers, the self-attention (SA) is arguably the key to its success which
28
enables input-dependent global interactions, in contrast to convolution operation which
29
constraints interactions in a local region with a shared kernel.
30
31
The **Attention** module is mathematically written as shown in **Equation 1**.
32
33
| ![Attention Equation](https://i.imgur.com/thdHvQx.png) |
34
| :--: |
35
| Equation 1: The mathematical equation of attention (Source: Aritra and Ritwik) |
36
37
Where:
38
39
- `Q` is the query
40
- `K` is the key
41
- `V` is the value
42
- `d_k` is the dimension of the key
43
44
With **self-attention**, the query, key, and value are all sourced from the input
45
sequence. Let us rewrite the attention equation for self-attention as shown in **Equation
46
2**.
47
48
| ![Self-Attention Equation](https://i.imgur.com/OFsmVdP.png) |
49
| :--: |
50
| Equation 2: The mathematical equation of self-attention (Source: Aritra and Ritwik) |
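
For reference, the two images above correspond to the standard scaled dot-product
attention formulation (our plain-text transcription, using the same backtick notation
as the rest of this tutorial):

- Equation 1: `Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V`
- Equation 2: with `Q = X W_q`, `K = X W_k`, and `V = X W_v` all derived from the same
  input `X`, self-attention becomes
  `SA(X) = softmax(X W_q (X W_k)^T / sqrt(d_k)) X W_v`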

Looking at the equation of self-attention, we see that its cost is quadratic in the
number of tokens: as the sequence grows, so do the computation time and the memory cost.
To mitigate this problem and make Transformers more interpretable, Yang et al.
have tried to replace the Self-Attention module with better components.

**The Solution**

Yang et al. introduce the Focal Modulation layer to serve as a
seamless replacement for the Self-Attention Layer. The layer boasts high
interpretability, making it a valuable tool for Deep Learning practitioners.

In this tutorial, we will delve into the practical application of this layer by training
the entire model on the CIFAR-10 dataset and visually interpreting the layer's
performance.

Note: We try to align our implementation with the
[official implementation](https://github.com/microsoft/FocalNet).
"""

"""
## Setup and Imports

We use tensorflow version `2.11.0` for this tutorial.
"""

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers.experimental import AdamW
from typing import Optional, Tuple, List
from matplotlib import pyplot as plt
from random import randint

# Set seed for reproducibility.
tf.keras.utils.set_random_seed(42)

"""
## Global Configuration

We do not have any strong rationale behind choosing these hyperparameters. Please feel
free to change the configuration and train the model.
"""

# DATA
TRAIN_SLICE = 40000
BUFFER_SIZE = 2048
BATCH_SIZE = 1024
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
IMAGE_SIZE = 48
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 25

"""
113
## Load and process the CIFAR-10 dataset
114
"""
115
116
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
117
(x_train, y_train), (x_val, y_val) = (
118
(x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
119
(x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
120
)

"""
### Build the augmentations

We use the `keras.Sequential` API to compose all the individual augmentation steps
into one pipeline.
"""

# Build the `train` augmentation pipeline.
train_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    ],
    name="test_data_augmentation",
)

"""
### Build `tf.data` pipeline
"""

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(
        lambda image, label: (train_aug(image), label), num_parallel_calls=AUTO
    )
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
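
"""
Before moving on, we can peek at one batch to confirm the shapes coming out of the
pipeline (a quick sanity check we added; it is not required for training).
"""

sample_images, sample_labels = next(iter(train_ds))
print(sample_images.shape)  # (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3)
print(sample_labels.shape)  # (BATCH_SIZE, 1)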

"""
## Architecture

We pause here to take a quick look at the Architecture of the Focal Modulation Network.
**Figure 1** shows how every individual layer is compiled into a single model. This gives
us a bird's eye view of the entire architecture.

| ![Diagram of the model](https://i.imgur.com/v5HYV5R.png) |
| :--: |
| Figure 1: A diagram of the Focal Modulation model (Source: Aritra and Ritwik) |

We dive deep into each of these layers in the following sections. This is the order we
will follow:

- Patch Embedding Layer
- Focal Modulation Block
    - Multi-Layer Perceptron
    - Focal Modulation Layer
        - Hierarchical Contextualization
        - Gated Aggregation
    - Building Focal Modulation Block
- Building the Basic Layer

To better understand the architecture in a format we are well versed in, let us see how
the Focal Modulation Network would look when drawn like a Transformer architecture.

**Figure 2** shows the encoder layer of a traditional Transformer architecture where Self
Attention is replaced with the Focal Modulation layer.

The <font color="blue">blue</font> blocks represent the Focal Modulation block. A stack
of these blocks builds a single Basic Layer. The <font color="green">green</font> blocks
represent the Focal Modulation layer.

| ![The Entire Architecture](https://i.imgur.com/PduYD6m.png) |
| :--: |
| Figure 2: The Entire Architecture (Source: Aritra and Ritwik) |
"""

"""
## Patch Embedding Layer

The patch embedding layer is used to patchify the input images and project them into a
latent space. This layer is also used as the down-sampling layer in the architecture.
"""


class PatchEmbed(layers.Layer):
    """Image patch embedding layer, also acts as the down-sampling layer.

    Args:
        image_size (Tuple[int]): Input image resolution.
        patch_size (Tuple[int]): Patch spatial resolution.
        embed_dim (int): Embedding dimension.
    """

    def __init__(
        self,
        image_size: Tuple[int] = (224, 224),
        patch_size: Tuple[int] = (4, 4),
        embed_dim: int = 96,
        **kwargs,
    ):
        super().__init__(**kwargs)
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]
        self.proj = layers.Conv2D(
            filters=embed_dim, kernel_size=patch_size, strides=patch_size
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
        self.norm = keras.layers.LayerNormalization(epsilon=1e-7)

    def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, int, int, int]:
        """Patchifies the image and converts it into tokens.

        Args:
            x: Tensor of shape (B, H, W, C)

        Returns:
            A tuple of the processed tensor, height of the projected
            feature map, width of the projected feature map, number
            of channels of the projected feature map.
        """
        # Project the inputs.
        x = self.proj(x)

        # Obtain the shape from the projected tensor.
        height = tf.shape(x)[1]
        width = tf.shape(x)[2]
        channels = tf.shape(x)[3]

        # B, H, W, C -> B, H*W, C
        x = self.norm(self.flatten(x))

        return x, height, width, channels


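"""
As a quick sanity check (ours, purely illustrative), let us patchify a dummy batch and
inspect the token shape. With a `48x48` input and `4x4` patches we expect
`12 x 12 = 144` tokens, each projected to `embed_dim` channels.
"""

dummy_images = tf.random.normal((1, IMAGE_SIZE, IMAGE_SIZE, 3))
dummy_tokens, dummy_height, dummy_width, dummy_channels = PatchEmbed(
    image_size=(IMAGE_SIZE, IMAGE_SIZE), patch_size=(4, 4), embed_dim=96
)(dummy_images)
print(dummy_tokens.shape)  # (1, 144, 96)
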
"""
282
## Focal Modulation block
283
284
A Focal Modulation block can be considered as a single Transformer Block with the Self
285
Attention (SA) module being replaced with Focal Modulation module, as we saw in **Figure
286
2**.
287
288
Let us recall how a focal modulation block is supposed to look like with the aid of the
289
**Figure 3**.
290
291
292
| ![Focal Modulation Block](https://i.imgur.com/bPYTSiB.png) |
293
| :--: |
294
| Figure 3: The isolated view of the Focal Modulation Block (Source: Aritra and Ritwik) |
295
296
The Focal Modulation Block consists of:
297
- Multilayer Perceptron
298
- Focal Modulation layer
299
"""
300
301
"""
302
### Multilayer Perceptron
303
"""
304
305
306
def MLP(
307
in_features: int,
308
hidden_features: Optional[int] = None,
309
out_features: Optional[int] = None,
310
mlp_drop_rate: float = 0.0,
311
):
312
hidden_features = hidden_features or in_features
313
out_features = out_features or in_features
314
315
return keras.Sequential(
316
[
317
layers.Dense(units=hidden_features, activation=keras.activations.gelu),
318
layers.Dense(units=out_features),
319
layers.Dropout(rate=mlp_drop_rate),
320
]
321
)


"""
### Focal Modulation layer

In a typical Transformer architecture, for each visual token (**query**) `x_i in R^C` in
an input feature map `X in R^{HxWxC}`, a **generic encoding process** produces a feature
representation `y_i in R^C`.

The encoding process consists of **interaction** (with its surroundings, e.g. a dot
product) and **aggregation** (over the contexts, e.g. a weighted mean).

We will talk about two types of encoding here:
- Interaction and then Aggregation in **Self-Attention**
- Aggregation and then Interaction in **Focal Modulation**

**Self-Attention**

| ![Self-Attention Expression](https://i.imgur.com/heBYp0F.png) |
| :--: |
| **Figure 4**: Self-Attention module. (Source: Aritra and Ritwik) |

| ![Aggregation and Interaction for Self-Attention](https://i.imgur.com/j1k8Xmy.png) |
| :--: |
| **Equation 3:** Aggregation and Interaction in Self-Attention (Source: Aritra and Ritwik) |

As shown in **Figure 4**, the query and the key interact (in the interaction step) with
each other to output the attention scores. The weighted aggregation of the value comes
next, known as the aggregation step.

**Focal Modulation**

| ![Focal Modulation module](https://i.imgur.com/tmbLgQl.png) |
| :--: |
| **Figure 5**: Focal Modulation module. (Source: Aritra and Ritwik) |

| ![Aggregation and Interaction in Focal Modulation](https://i.imgur.com/gsvJfWp.png) |
| :--: |
| **Equation 4:** Aggregation and Interaction in Focal Modulation (Source: Aritra and Ritwik) |

**Figure 5** depicts the Focal Modulation layer. `q()` is the query projection
function. It is a **linear layer** that projects the query into a latent space. `m()` is
the context aggregation function. Unlike self-attention, in focal modulation the
aggregation step takes place before the interaction step.
"""

"""
While `q()` is pretty straightforward to understand, the context aggregation function
`m()` is more complex. Therefore, this section will focus on `m()`.

| ![Context Aggregation](https://i.imgur.com/uqIRXI7.png)|
| :--: |
| **Figure 6**: Context Aggregation function `m()`. (Source: Aritra and Ritwik) |

The context aggregation function `m()` consists of two parts as shown in **Figure 6**:
- Hierarchical Contextualization
- Gated Aggregation
"""

"""
#### Hierarchical Contextualization

| ![Hierarchical Contextualization](https://i.imgur.com/q875c83.png)|
| :--: |
| **Figure 7**: Hierarchical Contextualization (Source: Aritra and Ritwik) |

In **Figure 7**, we see that the input is first projected linearly. This linear projection
produces `Z^0`, which can be expressed as follows:

| ![Linear projection of z_not](https://i.imgur.com/pd0Z2Of.png) |
| :--: |
| Equation 5: Linear projection of `Z^0` (Source: Aritra and Ritwik) |

`Z^0` is then passed on to a series of Depth-Wise (DWConv) Conv and
[GeLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/gelu) layers. The
authors term each block of DWConv and GeLU as levels denoted by `l`. In **Figure 6** we
have two levels. Mathematically this is represented as:

| ![Levels of modulation](https://i.imgur.com/ijGD1Df.png) |
| :--: |
| Equation 6: Levels of the modulation layer (Source: Aritra and Ritwik) |

where `l in {1, ... , L}`

The final feature map goes through a Global Average Pooling layer. This can be expressed
as follows:

| ![Avg Pool](https://i.imgur.com/MQzQhbo.png) |
| :--: |
| Equation 7: Average Pooling of the final feature (Source: Aritra and Ritwik)|
"""

"""
#### Gated Aggregation

| ![Gated Aggregation](https://i.imgur.com/LwrdDKo.png)|
| :--: |
| **Figure 8**: Gated Aggregation (Source: Aritra and Ritwik) |

Now that we have `L+1` intermediate feature maps by virtue of the Hierarchical
Contextualization step, we need a gating mechanism that lets some features pass and
prohibits others. This can be implemented with the attention module.
Later in the tutorial, we will visualize these gates to better understand their
usefulness.

First, we build the weights for aggregation. Here we apply a **linear layer** on the input
feature map that projects it into `L+1` dimensions.

| ![Gates](https://i.imgur.com/1CgEo1G.png) |
| :--: |
| Equation 8: Gates (Source: Aritra and Ritwik) |

Next we perform the weighted aggregation over the contexts.

| ![z out](https://i.imgur.com/mpJ712R.png) |
| :--: |
| Equation 9: Final feature map (Source: Aritra and Ritwik) |

To enable communication across different channels, we use another linear layer `h()`
to obtain the modulator.

| ![Modulator](https://i.imgur.com/0EpT3Ti.png) |
| :--: |
| Equation 10: Modulator (Source: Aritra and Ritwik) |

To sum up the Focal Modulation layer, we have:

| ![Focal Modulation Layer](https://i.imgur.com/1QIhvYA.png) |
| :--: |
| Equation 11: Focal Modulation Layer (Source: Aritra and Ritwik) |
"""


class FocalModulationLayer(layers.Layer):
    """The Focal Modulation layer includes query projection & context aggregation.

    Args:
        dim (int): Projection dimension.
        focal_window (int): Window size for focal modulation.
        focal_level (int): The number of focal levels.
        focal_factor (int): Factor of focal modulation.
        proj_drop_rate (float): Rate of dropout.
    """

    def __init__(
        self,
        dim: int,
        focal_window: int,
        focal_level: int,
        focal_factor: int = 2,
        proj_drop_rate: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.proj_drop_rate = proj_drop_rate

        # Project the input feature into a new feature space using a
        # linear layer. Note the `units` used. We will be projecting the input
        # feature all at once and split the projection into query, context,
        # and gates.
        self.initial_proj = layers.Dense(
            units=(2 * self.dim) + (self.focal_level + 1),
            use_bias=True,
        )
        self.focal_layers = list()
        self.kernel_sizes = list()
        for idx in range(self.focal_level):
            kernel_size = (self.focal_factor * idx) + self.focal_window
            depth_gelu_block = keras.Sequential(
                [
                    layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2)),
                    layers.Conv2D(
                        filters=self.dim,
                        kernel_size=kernel_size,
                        activation=keras.activations.gelu,
                        groups=self.dim,
                        use_bias=False,
                    ),
                ]
            )
            self.focal_layers.append(depth_gelu_block)
            self.kernel_sizes.append(kernel_size)
        self.activation = keras.activations.gelu
        self.gap = layers.GlobalAveragePooling2D(keepdims=True)
        self.modulator_proj = layers.Conv2D(
            filters=self.dim,
            kernel_size=(1, 1),
            use_bias=True,
        )
        self.proj = layers.Dense(units=self.dim)
        self.proj_drop = layers.Dropout(self.proj_drop_rate)

    def call(self, x: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
        """Forward pass of the layer.

        Args:
            x: Tensor of shape (B, H, W, C)
        """
        # Apply the linear projection to the input feature map.
        x_proj = self.initial_proj(x)

        # Split the projected x into query, context and gates.
        query, context, self.gates = tf.split(
            value=x_proj,
            num_or_size_splits=[self.dim, self.dim, self.focal_level + 1],
            axis=-1,
        )

        # Context aggregation
        context = self.focal_layers[0](context)
        context_all = context * self.gates[..., 0:1]
        for idx in range(1, self.focal_level):
            context = self.focal_layers[idx](context)
            context_all += context * self.gates[..., idx : idx + 1]

        # Build the global context
        context_global = self.activation(self.gap(context))
        context_all += context_global * self.gates[..., self.focal_level :]

        # Focal Modulation
        self.modulator = self.modulator_proj(context_all)
        x_output = query * self.modulator

        # Project the output and apply dropout
        x_output = self.proj(x_output)
        x_output = self.proj_drop(x_output)

        return x_output


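"""
Before wiring the layer into a block, a quick shape check (ours, purely illustrative):
the layer preserves the spatial shape of its input and exposes the `gates` and
`modulator` attributes that we visualize later in this tutorial.
"""

dummy_feature_map = tf.random.normal((1, 12, 12, 64))
dummy_focal_layer = FocalModulationLayer(dim=64, focal_window=3, focal_level=2)
print(dummy_focal_layer(dummy_feature_map).shape)  # (1, 12, 12, 64)
print(dummy_focal_layer.gates.shape)  # (1, 12, 12, 3) -> L + 1 gating maps
print(dummy_focal_layer.modulator.shape)  # (1, 12, 12, 64)
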
"""
557
### The Focal Modulation block
558
559
Finally, we have all the components we need to build the Focal Modulation block. Here we
560
take the MLP and Focal Modulation layer together and build the Focal Modulation block.
561
"""
562
563
564
class FocalModulationBlock(layers.Layer):
565
"""Combine FFN and Focal Modulation Layer.
566
567
Args:
568
dim (int): Number of input channels.
569
input_resolution (Tuple[int]): Input resulotion.
570
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
571
drop (float): Dropout rate.
572
drop_path (float): Stochastic depth rate.
573
focal_level (int): Number of focal levels.
574
focal_window (int): Focal window size at first focal level
575
"""
576
577
def __init__(
578
self,
579
dim: int,
580
input_resolution: Tuple[int],
581
mlp_ratio: float = 4.0,
582
drop: float = 0.0,
583
drop_path: float = 0.0,
584
focal_level: int = 1,
585
focal_window: int = 3,
586
**kwargs,
587
):
588
super().__init__(**kwargs)
589
self.dim = dim
590
self.input_resolution = input_resolution
591
self.mlp_ratio = mlp_ratio
592
self.focal_level = focal_level
593
self.focal_window = focal_window
594
self.norm = layers.LayerNormalization(epsilon=1e-5)
595
self.modulation = FocalModulationLayer(
596
dim=self.dim,
597
focal_window=self.focal_window,
598
focal_level=self.focal_level,
599
proj_drop_rate=drop,
600
)
601
mlp_hidden_dim = int(self.dim * self.mlp_ratio)
602
self.mlp = MLP(
603
in_features=self.dim,
604
hidden_features=mlp_hidden_dim,
605
mlp_drop_rate=drop,
606
)
607
608
def call(self, x: tf.Tensor, height: int, width: int, channels: int) -> tf.Tensor:
609
"""Processes the input tensor through the focal modulation block.
610
611
Args:
612
x (tf.Tensor): Inputs of the shape (B, L, C)
613
height (int): The height of the feature map
614
width (int): The width of the feature map
615
channels (int): The number of channels of the feature map
616
617
Returns:
618
The processed tensor.
619
"""
620
shortcut = x
621
622
# Focal Modulation
623
x = tf.reshape(x, shape=(-1, height, width, channels))
624
x = self.modulation(x)
625
x = tf.reshape(x, shape=(-1, height * width, channels))
626
627
# FFN
628
x = shortcut + x
629
x = x + self.mlp(self.norm(x))
630
return x
631
632
633
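"""
The block operates on tokens of shape `(B, L, C)` while the Focal Modulation layer
inside it expects spatial maps, which is why `height`, `width`, and `channels` are
passed alongside the tokens. A quick check of this interface (ours, illustrative):
"""

dummy_block_inputs = tf.random.normal((1, 144, 64))  # 144 tokens -> a 12 x 12 map
dummy_block = FocalModulationBlock(dim=64, input_resolution=(12, 12), focal_level=2)
print(dummy_block(dummy_block_inputs, height=12, width=12, channels=64).shape)  # (1, 144, 64)
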
"""
634
## The Basic Layer
635
636
The basic layer consists of a collection of Focal Modulation blocks. This is
637
illustrated in **Figure 9**.
638
639
| ![Basic Layer](https://i.imgur.com/UcZV0K6.png) |
640
| :--: |
641
| **Figure 9**: Basic Layer, a collection of focal modulation blocks. (Source: Aritra and Ritwik) |
642
643
Notice how in **Fig. 9** there are more than one focal modulation blocks denoted by `Nx`.
644
This shows how the Basic Layer is a collection of Focal Modulation blocks.
645
"""
646
647
648
class BasicLayer(layers.Layer):
649
"""Collection of Focal Modulation Blocks.
650
651
Args:
652
dim (int): Dimensions of the model.
653
out_dim (int): Dimension used by the Patch Embedding Layer.
654
input_resolution (Tuple[int]): Input image resolution.
655
depth (int): The number of Focal Modulation Blocks.
656
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
657
drop (float): Dropout rate.
658
downsample (tf.keras.layers.Layer): Downsampling layer at the end of the layer.
659
focal_level (int): The current focal level.
660
focal_window (int): Focal window used.
661
"""
662
663
def __init__(
664
self,
665
dim: int,
666
out_dim: int,
667
input_resolution: Tuple[int],
668
depth: int,
669
mlp_ratio: float = 4.0,
670
drop: float = 0.0,
671
downsample=None,
672
focal_level: int = 1,
673
focal_window: int = 1,
674
**kwargs,
675
):
676
super().__init__(**kwargs)
677
self.dim = dim
678
self.input_resolution = input_resolution
679
self.depth = depth
680
self.blocks = [
681
FocalModulationBlock(
682
dim=dim,
683
input_resolution=input_resolution,
684
mlp_ratio=mlp_ratio,
685
drop=drop,
686
focal_level=focal_level,
687
focal_window=focal_window,
688
)
689
for i in range(self.depth)
690
]
691
692
# Downsample layer at the end of the layer
693
if downsample is not None:
694
self.downsample = downsample(
695
image_size=input_resolution,
696
patch_size=(2, 2),
697
embed_dim=out_dim,
698
)
699
else:
700
self.downsample = None
701
702
def call(
703
self, x: tf.Tensor, height: int, width: int, channels: int
704
) -> Tuple[tf.Tensor, int, int, int]:
705
"""Forward pass of the layer.
706
707
Args:
708
x (tf.Tensor): Tensor of shape (B, L, C)
709
height (int): Height of feature map
710
width (int): Width of feature map
711
channels (int): Embed Dim of feature map
712
713
Returns:
714
A tuple of the processed tensor, changed height, width, and
715
dim of the tensor.
716
"""
717
# Apply Focal Modulation Blocks
718
for block in self.blocks:
719
x = block(x, height, width, channels)
720
721
# Except the last Basic Layer, all the layers have
722
# downsample at the end of it.
723
if self.downsample is not None:
724
x = tf.reshape(x, shape=(-1, height, width, channels))
725
x, height_o, width_o, channels_o = self.downsample(x)
726
else:
727
height_o, width_o, channels_o = height, width, channels
728
729
return x, height_o, width_o, channels_o
730
731
"""
## The Focal Modulation Network model

This is the model that ties everything together.
It consists of a collection of Basic Layers with a classification head.
For a recap of how this is structured, refer to **Figure 1**.
"""


class FocalModulationNetwork(keras.Model):
    """The Focal Modulation Network.

    Parameters:
        image_size (Tuple[int]): Spatial size of images used.
        patch_size (Tuple[int]): Patch size of each patch.
        num_classes (int): Number of classes used for classification.
        embed_dim (int): Patch embedding dimension.
        depths (List[int]): The number of Focal Modulation Blocks in each Basic Layer.
        mlp_ratio (float): Ratio of expansion for the intermediate layer of MLP.
        drop_rate (float): The dropout rate for FM and MLP layers.
        focal_levels (list): How many focal levels at all stages.
            Note that this excludes the finest-grain level.
        focal_windows (list): The focal window size at all stages.
    """

    def __init__(
        self,
        image_size: Tuple[int] = (48, 48),
        patch_size: Tuple[int] = (4, 4),
        num_classes: int = 10,
        embed_dim: int = 256,
        depths: List[int] = [2, 3, 2],
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.1,
        focal_levels=[2, 2, 2],
        focal_windows=[3, 3, 3],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = len(depths)
        embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_features = embed_dim[-1]
        self.mlp_ratio = mlp_ratio
        self.patch_embed = PatchEmbed(
            image_size=image_size,
            patch_size=patch_size,
            embed_dim=embed_dim[0],
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patch_resolution
        self.patches_resolution = patches_resolution
        self.pos_drop = layers.Dropout(drop_rate)
        self.basic_layers = list()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=embed_dim[i_layer],
                out_dim=(
                    embed_dim[i_layer + 1] if (i_layer < self.num_layers - 1) else None
                ),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_level=focal_levels[i_layer],
                focal_window=focal_windows[i_layer],
            )
            self.basic_layers.append(layer)
        self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
        self.avgpool = layers.GlobalAveragePooling1D()
        self.flatten = layers.Flatten()
        self.head = layers.Dense(self.num_classes, activation="softmax")

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Forward pass of the model.

        Args:
            x: Tensor of shape (B, H, W, C)

        Returns:
            The classification probabilities.
        """
        # Patch Embed the input images.
        x, height, width, channels = self.patch_embed(x)
        x = self.pos_drop(x)

        for idx, layer in enumerate(self.basic_layers):
            x, height, width, channels = layer(x, height, width, channels)

        x = self.norm(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x


"""
834
## Train the model
835
836
Now with all the components in place and the architecture actually built, we are ready to
837
put it to good use.
838
839
In this section, we train our Focal Modulation model on the CIFAR-10 dataset.
840
"""
841
842
"""
843
### Visualization Callback
844
845
A key feature of the Focal Modulation Network is explicit input-dependency. This means
846
the modulator is calculated by looking at the local features around the target location,
847
so it depends on the input. In very simple terms, this makes interpretation easy. We can
848
simply lay down the gating values and the original image, next to each other to see how
849
the gating mechanism works.
850
851
The authors of the paper visualize the gates and the modulator in order to focus on the
852
interpretability of the Focal Modulation layer. Below is a visualization
853
callback that shows the gates and modulator of a specific layer in the model while the
854
model trains.
855
856
We will notice later that as the model trains, the visualizations get better.
857
858
The gates appear to selectively permit certain aspects of the input image to pass
859
through, while gently disregarding others, ultimately leading to improved classification
860
accuracy.
861
"""
862
863
864
def display_grid(
865
test_images: tf.Tensor,
866
gates: tf.Tensor,
867
modulator: tf.Tensor,
868
):
869
"""Displays the image with the gates and modulator overlayed.
870
871
Args:
872
test_images (tf.Tensor): A batch of test images.
873
gates (tf.Tensor): The gates of the Focal Modualtion Layer.
874
modulator (tf.Tensor): The modulator of the Focal Modulation Layer.
875
"""
876
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))
877
878
# Radomly sample an image from the batch.
879
index = randint(0, BATCH_SIZE - 1)
880
orig_image = test_images[index]
881
gate_image = gates[index]
882
modulator_image = modulator[index]
883
884
# Original Image
885
ax[0].imshow(orig_image)
886
ax[0].set_title("Original:")
887
ax[0].axis("off")
888
889
for index in range(1, 5):
890
img = ax[index].imshow(orig_image)
891
if index != 4:
892
overlay_image = gate_image[..., index - 1]
893
title = f"G {index}:"
894
else:
895
overlay_image = tf.norm(modulator_image, ord=2, axis=-1)
896
title = f"MOD:"
897
898
ax[index].imshow(
899
overlay_image, cmap="inferno", alpha=0.6, extent=img.get_extent()
900
)
901
ax[index].set_title(title)
902
ax[index].axis("off")
903
904
plt.axis("off")
905
plt.show()
906
plt.close()
907
908
"""
### TrainMonitor
"""

# Taking a batch of test inputs to measure the model's progress.
test_images, test_labels = next(iter(test_ds))
upsampler = tf.keras.layers.UpSampling2D(
    size=(4, 4),
    interpolation="bilinear",
)


class TrainMonitor(keras.callbacks.Callback):
    def __init__(self, epoch_interval=None):
        super().__init__()
        self.epoch_interval = epoch_interval

    def on_epoch_end(self, epoch, logs=None):
        if self.epoch_interval and epoch % self.epoch_interval == 0:
            _ = self.model(test_images)

            # Take the mid layer for visualization
            gates = self.model.basic_layers[1].blocks[-1].modulation.gates
            gates = upsampler(gates)
            modulator = self.model.basic_layers[1].blocks[-1].modulation.modulator
            modulator = upsampler(modulator)

            # Display the grid of gates and modulator.
            display_grid(test_images=test_images, gates=gates, modulator=modulator)


"""
### Learning Rate scheduler
"""


# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
    ):
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        if self.total_steps < self.warmup_steps:
            raise ValueError("Total_steps must be larger or equal to warmup_steps.")
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "Learning_rate_base must be larger or equal to "
                    "warmup_learning_rate."
                )
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)
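
"""
Optionally, we can plot the schedule to see the linear warmup followed by the cosine
decay (a quick visualization we added; it is not required for training).
"""

lrs = [scheduled_lrs(step).numpy() for step in range(total_steps)]
plt.plot(lrs)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.show()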

"""
### Initialize, compile and train the model
"""

focal_mod_net = FocalModulationNetwork()
optimizer = AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)

# Compile and train the model.
focal_mod_net.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
history = focal_mod_net.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[TrainMonitor(epoch_interval=10)],
)

"""
## Plot loss and accuracy
"""

plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()

"""
## Test visualizations

Let's test our model on some test images and see what the gates look like.
"""

test_images, test_labels = next(iter(test_ds))
_ = focal_mod_net(test_images)

# Take the mid layer for visualization
gates = focal_mod_net.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = focal_mod_net.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)

# Plot the test images with the gates and modulator overlaid.
for row in range(5):
    display_grid(
        test_images=test_images,
        gates=gates,
        modulator=modulator,
    )

"""
## Conclusion

The proposed Focal Modulation Network architecture is a mechanism that allows different
parts of an image to interact with each other in a way that depends on the image itself.
It works by first gathering different levels of context information around each part of
the image (the "query token"), then using a gate to decide which context information is
most relevant, and finally combining the chosen information in a simple but effective
way.

This is meant as a replacement for the Self-Attention mechanism of the Transformer
architecture. The key feature that makes this research notable is not the conception of
attention-less networks, but rather the introduction of an equally powerful architecture
that is interpretable.

The authors also mention that they created a series of Focal Modulation Networks
(FocalNets) that significantly outperform their Self-Attention counterparts with a
fraction of the parameters and pretraining data.

The FocalNets architecture has the potential to deliver impressive results and offers a
simple implementation. Its promising performance and ease of use make it an attractive
alternative to Self-Attention for researchers to explore in their own projects. It could
potentially become widely adopted by the Deep Learning community in the near future.

## Acknowledgement

We would like to thank [PyImageSearch](https://pyimagesearch.com/) for providing us with
a Colab Pro account, [JarvisLabs.ai](https://cloud.jarvislabs.ai/) for GPU credits,
and also Microsoft Research for providing an
[official implementation](https://github.com/microsoft/FocalNet) of their paper.
We would also like to extend our gratitude to the first author of the
paper, [Jianwei Yang](https://twitter.com/jw2yang4ai), who reviewed this tutorial
extensively.
"""