Path: blob/master/examples/vision/focal_modulation_network.py
8412 views
"""1Title: Focal Modulation: A replacement for Self-Attention2Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)3Date created: 2023/01/254Last modified: 2026/01/275Description: Image classification with Focal Modulation Networks.6Accelerator: GPU7Converted to Keras 3 by: [LakshmiKalaKadali](https://github.com/LakshmiKalaKadali)8"""910"""11## Introduction1213This tutorial aims to provide a comprehensive guide to the implementation of14Focal Modulation Networks, as presented in15[Yang et al.](https://arxiv.org/abs/2203.11926).1617This tutorial will provide a formal, minimalistic approach to implementing Focal18Modulation Networks and explore its potential applications in the field of Deep Learning.1920**Problem statement**2122The Transformer architecture ([Vaswani et al.](https://arxiv.org/abs/1706.03762)),23which has become the de facto standard in most Natural Language Processing tasks, has24also been applied to the field of computer vision, e.g. Vision25Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929v2)).2627> In Transformers, the self-attention (SA) is arguably the key to its success which28enables input-dependent global interactions, in contrast to convolution operation which29constraints interactions in a local region with a shared kernel.3031The **Attention** module is mathematically written as shown in **Equation 1**.3233|  |34| :--: |35| Equation 1: The mathematical equation of attention (Source: Aritra and Ritwik) |3637Where:3839- `Q` is the query40- `K` is the key41- `V` is the value42- `d_k` is the dimension of the key4344With **self-attention**, the query, key, and value are all sourced from the input45sequence. Let us rewrite the attention equation for self-attention as shown in **Equation462**.4748|  |49| :--: |50| Equation 2: The mathematical equation of self-attention (Source: Aritra and Ritwik) |5152Upon looking at the equation of self-attention, we see that it is a quadratic equation.53Therefore, as the number of tokens increase, so does the computation time (cost too). To54mitigate this problem and make Transformers more interpretable, Yang et al.55have tried to replace the Self-Attention module with better components.5657**The Solution**5859Yang et al. introduce the Focal Modulation layer to serve as a60seamless replacement for the Self-Attention Layer. The layer boasts high61interpretability, making it a valuable tool for Deep Learning practitioners.6263In this tutorial, we will delve into the practical application of this layer by training64the entire model on the CIFAR-10 dataset and visually interpreting the layer's65performance.6667Note: We try to align our implementation with the68[official implementation](https://github.com/microsoft/FocalNet).69"""7071"""72## Setup and Imports7374Keras 3 allows this model to run on JAX, PyTorch, or TensorFlow. We use keras.ops for all mathematical operations to ensure the code remains backend-agnostic.75"""7677import os7879# Set backend before importing keras80os.environ["KERAS_BACKEND"] = "tensorflow" # Or "torch" or "tensorflow"8182import numpy as np83import keras84from keras import layers85from keras import ops86from matplotlib import pyplot as plt87from random import randint8889# Set seed for reproducibility using Keras 3 utility.90keras.utils.set_random_seed(42)9192"""93## Global Configuration9495We do not have any strong rationale behind choosing these hyperparameters. Please feel96free to change the configuration and train the model.97"""9899# --- GLOBAL CONFIGURATION ---100TRAIN_SLICE = 40000101BATCH_SIZE = 128 # 1024102INPUT_SHAPE = (32, 32, 3)103IMAGE_SIZE = 48104NUM_CLASSES = 10105106LEARNING_RATE = 1e-4107WEIGHT_DECAY = 1e-4108EPOCHS = 20109110"""111## Data Loading with PyDataset112113Keras 3 introduces PyDataset as a standardized way to handle data.114It works identically across all backends and avoids the "Symbolic Tensor" issues often found115when using tf.data with JAX or PyTorch.116"""117118(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()119(x_train, y_train), (x_val, y_val) = (120(x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),121(x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),122)123124125class FocalDataset(keras.utils.PyDataset):126def __init__(self, x_data, y_data, batch_size, shuffle=False, **kwargs):127super().__init__(**kwargs)128self.x_data = x_data129self.y_data = y_data130self.batch_size = batch_size131self.shuffle = shuffle132self.indices = np.arange(len(x_data))133if self.shuffle:134np.random.shuffle(self.indices)135136def __len__(self):137return int(np.ceil(len(self.x_data) / self.batch_size))138139def __getitem__(self, idx):140start = idx * self.batch_size141end = min((idx + 1) * self.batch_size, len(self.x_data))142batch_indices = self.indices[start:end]143144x_batch = self.x_data[batch_indices]145y_batch = self.y_data[batch_indices]146147# Convert to backend-native tensors148x_batch = ops.convert_to_tensor(x_batch, dtype="float32")149y_batch = ops.convert_to_tensor(y_batch, dtype="int32")150151return x_batch, y_batch152153def on_epoch_end(self):154if self.shuffle:155np.random.shuffle(self.indices)156157158train_ds = FocalDataset(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)159val_ds = FocalDataset(x_val, y_val, batch_size=BATCH_SIZE, shuffle=False)160test_ds = FocalDataset(x_test, y_test, batch_size=BATCH_SIZE, shuffle=False)161162"""163## Architecture164165We pause here to take a quick look at the Architecture of the Focal Modulation Network.166**Figure 1** shows how every individual layer is compiled into a single model. This gives167us a bird's eye view of the entire architecture.168169|  |170| :--: |171| Figure 1: A diagram of the Focal Modulation model (Source: Aritra and Ritwik) |172173We dive deep into each of these layers in the following sections. This is the order we174will follow:175176177- Patch Embedding Layer178- Focal Modulation Block179- Multi-Layer Perceptron180- Focal Modulation Layer181- Hierarchical Contextualization182- Gated Aggregation183- Building Focal Modulation Block184- Building the Basic Layer185186To better understand the architecture in a format we are well versed in, let us see how187the Focal Modulation Network would look when drawn like a Transformer architecture.188189**Figure 2** shows the encoder layer of a traditional Transformer architecture where Self190Attention is replaced with the Focal Modulation layer.191192The <font color="blue">blue</font> blocks represent the Focal Modulation block. A stack193of these blocks builds a single Basic Layer. The <font color="green">green</font> blocks194represent the Focal Modulation layer.195196|  |197| :--: |198| Figure 2: The Entire Architecture (Source: Aritra and Ritwik) |199"""200201"""202## Patch Embedding Layer203204The patch embedding layer is used to patchify the input images and project them into a205latent space. This layer is also used as the down-sampling layer in the architecture.206"""207208209class PatchEmbed(layers.Layer):210"""Image patch embedding layer, also acts as the down-sampling layer.211212Args:213image_size (Tuple[int]): Input image resolution.214patch_size (Tuple[int]): Patch spatial resolution.215embed_dim (int): Embedding dimension.216"""217218def __init__(219self, image_size=(224, 224), patch_size=(4, 4), embed_dim=96, **kwargs220):221super().__init__(**kwargs)222self.patch_resolution = [223image_size[0] // patch_size[0],224image_size[1] // patch_size[1],225]226self.proj = layers.Conv2D(227filters=embed_dim, kernel_size=patch_size, strides=patch_size228)229self.flatten = layers.Reshape(target_shape=(-1, embed_dim))230self.norm = layers.LayerNormalization(epsilon=1e-7)231232def call(self, x):233"""Patchifies the image and converts into tokens.234235Args:236x: Tensor of shape (B, H, W, C)237238Returns:239A tuple of the processed tensor, height of the projected240feature map, width of the projected feature map, number241of channels of the projected feature map.242"""243x = self.proj(x)244shape = ops.shape(x)245height, width, channels = shape[1], shape[2], shape[3]246x = self.norm(self.flatten(x))247return x, height, width, channels248249250"""251## Focal Modulation block252253A Focal Modulation block can be considered as a single Transformer Block with the Self254Attention (SA) module being replaced with Focal Modulation module, as we saw in **Figure2552**.256257Let us recall how a focal modulation block is supposed to look like with the aid of the258**Figure 3**.259260261|  |262| :--: |263| Figure 3: The isolated view of the Focal Modulation Block (Source: Aritra and Ritwik) |264265The Focal Modulation Block consists of:266- Multilayer Perceptron267- Focal Modulation layer268"""269270"""271### Multilayer Perceptron272"""273274275def MLP(in_features, hidden_features=None, out_features=None, mlp_drop_rate=0.0):276hidden_features = hidden_features or in_features277out_features = out_features or in_features278return keras.Sequential(279[280layers.Dense(units=hidden_features, activation="gelu"),281layers.Dense(units=out_features),282layers.Dropout(rate=mlp_drop_rate),283]284)285286287"""288### Focal Modulation layer289290In a typical Transformer architecture, for each visual token (**query**) `x_i in R^C` in291an input feature map `X in R^{HxWxC}` a **generic encoding process** produces a feature292representation `y_i in R^C`.293294The encoding process consists of **interaction** (with its surroundings for e.g. a dot295product), and **aggregation** (over the contexts for e.g weighted mean).296297We will talk about two types of encoding here:298- Interaction and then Aggregation in **Self-Attention**299- Aggregation and then Interaction in **Focal Modulation**300301**Self-Attention**302303|  |304| :--: |305| **Figure 4**: Self-Attention module. (Source: Aritra and Ritwik) |306307|  |308| :--: |309| **Equation 3:** Aggregation and Interaction in Self-Attention(Surce: Aritra and Ritwik)|310311As shown in **Figure 4** the query and the key interact (in the interaction step) with312each other to output the attention scores. The weighted aggregation of the value comes313next, known as the aggregation step.314315**Focal Modulation**316317|  |318| :--: |319| **Figure 5**: Focal Modulation module. (Source: Aritra and Ritwik) |320321|  |322| :--: |323| **Equation 4:** Aggregation and Interaction in Focal Modulation (Source: Aritra and Ritwik) |324325**Figure 5** depicts the Focal Modulation layer. `q()` is the query projection326function. It is a **linear layer** that projects the query into a latent space. `m ()` is327the context aggregation function. Unlike self-attention, the328aggregation step takes place in focal modulation before the interaction step.329"""330331"""332While `q()` is pretty straightforward to understand, the context aggregation function333`m()` is more complex. Therefore, this section will focus on `m()`.334335| |336| :--: |337| **Figure 6**: Context Aggregation function `m()`. (Source: Aritra and Ritwik) |338339The context aggregation function `m()` consists of two parts as shown in **Figure 6**:340- Hierarchical Contextualization341- Gated Aggregation342"""343344"""345#### Hierarchical Contextualization346347| |348| :--: |349| **Figure 7**: Hierarchical Contextualization (Source: Aritra and Ritwik) |350351In **Figure 7**, we see that the input is first projected linearly. This linear projection352produces `Z^0`. Where `Z^0` can be expressed as follows:353354|  |355| :--: |356| Equation 5: Linear projection of `Z^0` (Source: Aritra and Ritwik) |357358`Z^0` is then passed on to a series of Depth-Wise (DWConv) Conv and359[GeLU](https://keras.io/api/layers/activations/#gelu-function) layers. The360authors term each block of DWConv and GeLU as levels denoted by `l`. In **Figure 6** we361have two levels. Mathematically this is represented as:362363|  |364| :--: |365| Equation 6: Levels of the modulation layer (Source: Aritra and Ritwik) |366367where `l in {1, ... , L}`368369The final feature map goes through a Global Average Pooling Layer. This can be expressed370as follows:371372|  |373| :--: |374| Equation 7: Average Pooling of the final feature (Source: Aritra and Ritwik)|375"""376377"""378#### Gated Aggregation379380| |381| :--: |382| **Figure 8**: Gated Aggregation (Source: Aritra and Ritwik) |383384Now that we have `L+1` intermediate feature maps by virtue of the Hierarchical385Contextualization step, we need a gating mechanism that lets some features pass and386prohibits others. This can be implemented with the attention module.387Later in the tutorial, we will visualize these gates to better understand their388usefulness.389390First, we build the weights for aggregation. Here we apply a **linear layer** on the input391feature map that projects it into `L+1` dimensions.392393|  |394| :--: |395| Eqation 8: Gates (Source: Aritra and Ritwik) |396397Next we perform the weighted aggregation over the contexts.398399|  |400| :--: |401| Eqation 9: Final feature map (Source: Aritra and Ritwik) |402403To enable communication across different channels, we use another linear layer `h()`404to obtain the modulator405406|  |407| :--: |408| Eqation 10: Modulator (Source: Aritra and Ritwik) |409410To sum up the Focal Modulation layer we have:411412|  |413| :--: |414| Eqation 11: Focal Modulation Layer (Source: Aritra and Ritwik) |415"""416417418class FocalModulationLayer(layers.Layer):419"""The Focal Modulation layer includes query projection & context aggregation.420421Args:422dim (int): Projection dimension.423focal_window (int): Window size for focal modulation.424focal_level (int): The current focal level.425focal_factor (int): Factor of focal modulation.426proj_drop_rate (float): Rate of dropout.427"""428429def __init__(430self,431dim,432focal_window,433focal_level,434focal_factor=2,435proj_drop_rate=0.0,436**kwargs,437):438super().__init__(**kwargs)439self.dim, self.focal_level = dim, focal_level440self.initial_proj = layers.Dense(units=(2 * dim) + (focal_level + 1))441self.focal_layers = [442keras.Sequential(443[444layers.ZeroPadding2D(445padding=((focal_factor * i + focal_window) // 2)446),447layers.Conv2D(448filters=dim,449kernel_size=(focal_factor * i + focal_window),450activation="gelu",451groups=dim,452use_bias=False,453),454]455)456for i in range(focal_level)457]458self.gap = layers.GlobalAveragePooling2D(keepdims=True)459self.mod_proj = layers.Conv2D(filters=dim, kernel_size=1)460self.proj = layers.Dense(units=dim)461self.proj_drop = layers.Dropout(proj_drop_rate)462463def call(self, x, training=None):464"""Forward pass of the layer.465466Args:467x: Tensor of shape (B, H, W, C)468"""469x_proj = self.initial_proj(x)470query, context, gates = ops.split(x_proj, [self.dim, 2 * self.dim], axis=-1)471472# Apply Softmax for numerical stability473gates = ops.softmax(gates, axis=-1)474self.gates = gates475476context = self.focal_layers[0](context)477context_all = context * gates[..., 0:1]478for i in range(1, self.focal_level):479context = self.focal_layers[i](context)480context_all = context_all + (context * gates[..., i : i + 1])481482context_global = ops.gelu(self.gap(context))483context_all = context_all + (context_global * gates[..., self.focal_level :])484485self.modulator = self.mod_proj(context_all)486x_out = query * self.modulator487return self.proj_drop(self.proj(x_out), training=training)488489490"""491### The Focal Modulation block492493Finally, we have all the components we need to build the Focal Modulation block. Here we494take the MLP and Focal Modulation layer together and build the Focal Modulation block.495"""496497498class FocalModulationBlock(layers.Layer):499"""Combine FFN and Focal Modulation Layer.500501Args:502dim (int): Number of input channels.503mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.504drop (float): Dropout rate.505focal_level (int): Number of focal levels.506focal_window (int): Focal window size at first focal level507"""508509def __init__(510self, dim, mlp_ratio=4.0, drop=0.0, focal_level=1, focal_window=3, **kwargs511):512super().__init__(**kwargs)513self.norm1 = layers.LayerNormalization(epsilon=1e-5)514self.modulation = FocalModulationLayer(515dim, focal_window, focal_level, proj_drop_rate=drop516)517self.norm2 = layers.LayerNormalization(epsilon=1e-5)518self.mlp = MLP(dim, int(dim * mlp_ratio), mlp_drop_rate=drop)519520def call(self, x, height=None, width=None, channels=None, training=None):521"""Processes the input tensor through the focal modulation block.522523Args:524x : Inputs of the shape (B, L, C)525height (int): The height of the feature map526width (int): The width of the feature map527channels (int): The number of channels of the feature map528529Returns:530The processed tensor.531"""532res = x533x = ops.reshape(x, (-1, height, width, channels))534x = self.modulation(x, training=training)535x = ops.reshape(x, (-1, height * width, channels))536x = res + x537return x + self.mlp(self.norm2(x), training=training)538539540"""541## The Basic Layer542543The basic layer consists of a collection of Focal Modulation blocks. This is544illustrated in **Figure 9**.545546|  |547| :--: |548| **Figure 9**: Basic Layer, a collection of focal modulation blocks. (Source: Aritra and Ritwik) |549550Notice how in **Fig. 9** there are more than one focal modulation blocks denoted by `Nx`.551This shows how the Basic Layer is a collection of Focal Modulation blocks.552"""553554555class BasicLayer(layers.Layer):556"""Collection of Focal Modulation Blocks.557558Args:559dim (int): Dimensions of the model.560out_dim (int): Dimension used by the Patch Embedding Layer.561input_res (Tuple[int]): Input image resolution.562depth (int): The number of Focal Modulation Blocks.563mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.564drop (float): Dropout rate.565downsample (keras.layers.Layer): Downsampling layer at the end of the layer.566focal_level (int): The current focal level.567focal_window (int): Focal window used.568"""569570def __init__(571self,572dim,573out_dim,574input_res,575depth,576mlp_ratio=4.0,577drop=0.0,578downsample=None,579focal_level=1,580focal_window=1,581**kwargs,582):583super().__init__(**kwargs)584self.blocks = [585FocalModulationBlock(dim, mlp_ratio, drop, focal_level, focal_window)586for _ in range(depth)587]588self.downsample = (589downsample(image_size=input_res, patch_size=(2, 2), embed_dim=out_dim)590if downsample591else None592)593594def call(self, x, height=None, width=None, channels=None, training=None):595"""Forward pass of the layer.596597Args:598x : Tensor of shape (B, L, C)599height (int): Height of feature map600width (int): Width of feature map601channels (int): Embed Dim of feature map602603Returns:604A tuple of the processed tensor, changed height, width, and605dim of the tensor.606"""607for block in self.blocks:608x = block(609x, height=height, width=width, channels=channels, training=training610)611if self.downsample:612x = ops.reshape(x, (-1, height, width, channels))613x, height, width, channels = self.downsample(x)614return x, height, width, channels615616617"""618## The Focal Modulation Network model619620This is the model that ties everything together.621It consists of a collection of Basic Layers with a classification head.622For a recap of how this is structured refer to **Figure 1**.623"""624625626class FocalModulationNetwork(keras.Model):627"""The Focal Modulation Network.628629Parameters:630image_size (Tuple[int]): Spatial size of images used.631patch_size (Tuple[int]): Patch size of each patch.632num_classes (int): Number of classes used for classification.633embed_dim (int): Patch embedding dimension.634depths (List[int]): Depth of each Focal Transformer block.635"""636637def __init__(638self,639image_size=(48, 48),640patch_size=(4, 4),641num_classes=10,642embed_dim=64,643depths=[2, 3, 2],644**kwargs,645):646super().__init__(**kwargs)647# Preprocessing integrated in model for backend-agnostic behavior648self.rescaling = layers.Rescaling(1.0 / 255.0)649self.resizing_larger = layers.Resizing(image_size[0] + 10, image_size[1] + 10)650self.random_crop = layers.RandomCrop(image_size[0], image_size[1])651self.resizing_target = layers.Resizing(image_size[0], image_size[1])652self.random_flip = layers.RandomFlip("horizontal")653654self.patch_embed = PatchEmbed(image_size, patch_size, embed_dim)655self.basic_layers = []656for i in range(len(depths)):657d = embed_dim * (2**i)658self.basic_layers.append(659BasicLayer(660dim=d,661out_dim=d * 2 if i < len(depths) - 1 else None,662input_res=(image_size[0] // (2**i), image_size[1] // (2**i)),663depth=depths[i],664downsample=PatchEmbed if i < len(depths) - 1 else None,665)666)667self.norm = layers.LayerNormalization(epsilon=1e-7)668self.avgpool = layers.GlobalAveragePooling1D()669self.head = layers.Dense(num_classes, activation="softmax")670671def call(self, x, training=None):672"""Forward pass of the layer.673674Args:675x: Tensor of shape (B, H, W, C)676677Returns:678The logits.679"""680x = self.rescaling(x)681if training:682x = self.resizing_larger(x)683x = self.random_crop(x)684x = self.random_flip(x)685else:686x = self.resizing_target(x)687688x, h, w, c = self.patch_embed(x)689for layer in self.basic_layers:690x, h, w, c = layer(x, height=h, width=w, channels=c, training=training)691return self.head(self.avgpool(self.norm(x)))692693694"""695## Train the model696697Now with all the components in place and the architecture actually built, we are ready to698put it to good use.699700In this section, we train our Focal Modulation model on the CIFAR-10 dataset.701"""702703"""704### Visualization Callback705706A key feature of the Focal Modulation Network is explicit input-dependency. This means707the modulator is calculated by looking at the local features around the target location,708so it depends on the input. In very simple terms, this makes interpretation easy. We can709simply lay down the gating values and the original image, next to each other to see how710the gating mechanism works.711712The authors of the paper visualize the gates and the modulator in order to focus on the713interpretability of the Focal Modulation layer. Below is a visualization714callback that shows the gates and modulator of a specific layer in the model while the715model trains.716717We will notice later that as the model trains, the visualizations get better.718719The gates appear to selectively permit certain aspects of the input image to pass720through, while gently disregarding others, ultimately leading to improved classification721accuracy.722"""723724725def display_grid(test_images, gates, modulator):726"""Displays the image with the gates and modulator overlayed.727728Args:729test_images: A batch of test images.730gates: The gates of the Focal Modualtion Layer.731modulator: The modulator of the Focal Modulation Layer.732"""733test_images_np = ops.convert_to_numpy(test_images) / 255.0734gates_np = ops.convert_to_numpy(gates)735mod_np = ops.convert_to_numpy(ops.norm(modulator, axis=-1))736737num_gates = gates_np.shape[-1]738idx = randint(0, test_images_np.shape[0] - 1)739fig, ax = plt.subplots(1, num_gates + 2, figsize=((num_gates + 2) * 4, 4))740741ax[0].imshow(test_images_np[idx])742ax[0].set_title("Original")743ax[0].axis("off")744for i in range(num_gates):745ax[i + 1].imshow(test_images_np[idx])746ax[i + 1].imshow(gates_np[idx, ..., i], cmap="inferno", alpha=0.6)747ax[i + 1].set_title(f"Gate {i+1}")748ax[i + 1].axis("off")749750ax[-1].imshow(test_images_np[idx])751ax[-1].imshow(mod_np[idx], cmap="inferno", alpha=0.6)752ax[-1].set_title("Modulator")753ax[-1].axis("off")754plt.show()755plt.close()756757758"""759### TrainMonitor760"""761762# Fetch test batch for callback763test_batch_images, _ = test_ds[0]764765766class TrainMonitor(keras.callbacks.Callback):767def __init__(self, epoch_interval=10):768super().__init__()769self.epoch_interval = epoch_interval770self.upsampler = layers.UpSampling2D(size=(4, 4), interpolation="bilinear")771772def on_epoch_end(self, epoch, logs=None):773if (epoch + 1) % self.epoch_interval == 0:774_ = self.model(test_batch_images, training=False)775layer = self.model.basic_layers[1].blocks[-1].modulation776display_grid(777test_batch_images,778self.upsampler(layer.gates),779self.upsampler(layer.modulator),780)781782783"""784### Learning Rate scheduler785"""786787788class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):789def __init__(self, lr_base, total_steps, warmup_steps):790super().__init__()791self.lr_base, self.total_steps, self.warmup_steps = (792lr_base,793total_steps,794warmup_steps,795)796797def __call__(self, step):798step = ops.cast(step, "float32")799cos_lr = (8000.5801* self.lr_base802* (8031804+ ops.cos(805np.pi806* (step - self.warmup_steps)807/ (self.total_steps - self.warmup_steps)808)809)810)811warmup_lr = (self.lr_base / self.warmup_steps) * step812return ops.where(813step < self.warmup_steps,814warmup_lr,815ops.where(step > self.total_steps, 0.0, cos_lr),816)817818819total_steps = (len(x_train) // BATCH_SIZE) * EPOCHS820scheduled_lrs = WarmUpCosine(LEARNING_RATE, total_steps, int(total_steps * 0.15))821822"""823### Initialize, compile and train the model824"""825826model = FocalModulationNetwork(image_size=(IMAGE_SIZE, IMAGE_SIZE))827model.compile(828optimizer=keras.optimizers.AdamW(829learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY, clipnorm=1.0830),831loss="sparse_categorical_crossentropy",832metrics=["accuracy"],833)834835history = model.fit(836train_ds,837validation_data=val_ds,838epochs=EPOCHS,839callbacks=[TrainMonitor(epoch_interval=5)],840)841842"""843## Plot loss and accuracy844"""845plt.figure(figsize=(12, 4))846plt.subplot(1, 2, 1)847plt.plot(history.history["loss"], label="Train Loss")848plt.plot(history.history["val_loss"], label="Val Loss")849plt.legend()850plt.subplot(1, 2, 2)851plt.plot(history.history["accuracy"], label="Train Acc")852plt.plot(history.history["val_accuracy"], label="Val Acc")853plt.legend()854plt.show()855856"""857## Test visualizations858859Let's test our model on some test images and see how the gates look like.860"""861test_images, test_labels = next(iter(test_ds))862863_ = model(test_images, training=False)864865target_layer = model.basic_layers[1].blocks[-1].modulation866gates = target_layer.gates867modulator = target_layer.modulator868869upsampler = layers.UpSampling2D(size=(4, 4), interpolation="bilinear")870gates_upsampled = upsampler(gates)871modulator_upsampled = upsampler(modulator)872873for row in range(5):874display_grid(875test_images=test_images,876gates=gates_upsampled,877modulator=modulator_upsampled,878)879880"""881## Conclusion882883The proposed architecture, the Focal Modulation Network884architecture is a mechanism that allows different885parts of an image to interact with each other in a way that depends on the image itself.886It works by first gathering different levels of context information around each part of887the image (the "query token"), then using a gate to decide which context information is888most relevant, and finally combining the chosen information in a simple but effective889way.890891This is meant as a replacement of Self-Attention mechanism from the Transformer892architecture. The key feature that makes this research notable is not the conception of893attention-less networks, but rather the introduction of a equally powerful architecture894that is interpretable.895896The authors also mention that they created a series of Focal Modulation Networks897(FocalNets) that significantly outperform Self-Attention counterparts and with a fraction898of parameters and pretraining data.899900The FocalNets architecture has the potential to deliver impressive results and offers a901simple implementation. Its promising performance and ease of use make it an attractive902alternative to Self-Attention for researchers to explore in their own projects. It could903potentially become widely adopted by the Deep Learning community in the near future.904905## Acknowledgement906907We would like to thank [PyImageSearch](https://pyimagesearch.com/) for providing with a908Colab Pro account, [JarvisLabs.ai](https://cloud.jarvislabs.ai/) for GPU credits,909and also Microsoft Research for providing an910[official implementation](https://github.com/microsoft/FocalNet) of their paper.911We would also like to extend our gratitude to the first author of the912paper [Jianwei Yang](https://twitter.com/jw2yang4ai) who reviewed this tutorial913extensively.914"""915916"""917## Relevant Chapters from Deep Learning with Python918- [Chapter 8: Image classification](https://deeplearningwithpython.io/chapters/chapter08_image-classification)919"""920921922