"""
2
Title: Involutional neural networks
3
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
4
Date created: 2021/07/25
5
Last modified: 2021/07/25
6
Description: Deep dive into location-specific and channel-agnostic "involution" kernels.
7
Accelerator: GPU
8
"""
9
10
"""
11
## Introduction
12
13
Convolution has been the basis of most modern neural
14
networks for computer vision. A convolution kernel is
15
spatial-agnostic and channel-specific. Because of this, it isn't able
16
to adapt to different visual patterns with respect to
17
different spatial locations. Along with location-related problems, the
18
receptive field of convolution creates challenges with regard to capturing
19
long-range spatial interactions.
20
21
To address the above issues, Li et. al. rethink the properties
22
of convolution in
23
[Involution: Inverting the Inherence of Convolution for VisualRecognition](https://arxiv.org/abs/2103.06255).
24
The authors propose the "involution kernel", that is location-specific and
25
channel-agnostic. Due to the location-specific nature of the operation,
26
the authors say that self-attention falls under the design paradigm of
27
involution.
28
29
This example describes the involution kernel, compares two image
30
classification models, one with convolution and the other with
31
involution, and also tries drawing a parallel with the self-attention
32
layer.
33
"""
34
35
"""
36
## Setup
37
"""
38
39
import os
40
41
os.environ["KERAS_BACKEND"] = "tensorflow"
42
43
import tensorflow as tf
44
import keras
45
import matplotlib.pyplot as plt
46
47
# Set seed for reproducibility.
48
tf.random.set_seed(42)
49
50
"""
51
## Convolution
52
53
Convolution remains the mainstay of deep neural networks for computer vision.
54
To understand Involution, it is necessary to talk about the
55
convolution operation.
56
57
![Imgur](https://i.imgur.com/MSKLsm5.png)
58
59
Consider an input tensor **X** with dimensions **H**, **W** and
60
**C_in**. We take a collection of **C_out** convolution kernels each of
61
shape **K**, **K**, **C_in**. With the multiply-add operation between
62
the input tensor and the kernels we obtain an output tensor **Y** with
63
dimensions **H**, **W**, **C_out**.
64
65
In the diagram above `C_out=3`. This makes the output tensor of shape H,
66
W and 3. One can notice that the convoltuion kernel does not depend on
67
the spatial position of the input tensor which makes it
68
**location-agnostic**. On the other hand, each channel in the output
69
tensor is based on a specific convolution filter which makes is
70
**channel-specific**.
71
"""

"""
## Involution

The idea is to have an operation that is both **location-specific**
and **channel-agnostic**. Trying to implement these specific properties poses
a challenge. With a fixed number of involution kernels (one for each
spatial position) we will **not** be able to process variable-resolution
input tensors.

To solve this problem, the authors *generate* each
kernel conditioned on specific spatial positions. With this method, we
should be able to process variable-resolution input tensors with ease.
The diagram below provides an intuition for this kernel-generation
method.

![Imgur](https://i.imgur.com/jtrGGQg.png)
"""


class Involution(keras.layers.Layer):
    def __init__(
        self, channel, group_number, kernel_size, stride, reduction_ratio, name
    ):
        super().__init__(name=name)

        # Initialize the parameters.
        self.channel = channel
        self.group_number = group_number
        self.kernel_size = kernel_size
        self.stride = stride
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        # Get the shape of the input.
        (_, height, width, num_channels) = input_shape

        # Scale the height and width with respect to the strides.
        height = height // self.stride
        width = width // self.stride

        # Define a layer that average pools the input tensor
        # if stride is more than 1.
        self.stride_layer = (
            keras.layers.AveragePooling2D(
                pool_size=self.stride, strides=self.stride, padding="same"
            )
            if self.stride > 1
            else tf.identity
        )
        # Define the kernel generation layer.
        self.kernel_gen = keras.Sequential(
            [
                keras.layers.Conv2D(
                    filters=self.channel // self.reduction_ratio, kernel_size=1
                ),
                keras.layers.BatchNormalization(),
                keras.layers.ReLU(),
                keras.layers.Conv2D(
                    filters=self.kernel_size * self.kernel_size * self.group_number,
                    kernel_size=1,
                ),
            ]
        )
        # Define the reshape layers.
        self.kernel_reshape = keras.layers.Reshape(
            target_shape=(
                height,
                width,
                self.kernel_size * self.kernel_size,
                1,
                self.group_number,
            )
        )
        self.input_patches_reshape = keras.layers.Reshape(
            target_shape=(
                height,
                width,
                self.kernel_size * self.kernel_size,
                num_channels // self.group_number,
                self.group_number,
            )
        )
        self.output_reshape = keras.layers.Reshape(
            target_shape=(height, width, num_channels)
        )

    def call(self, x):
        # Generate the kernel with respect to the input tensor.
        # B, H, W, K*K*G
        kernel_input = self.stride_layer(x)
        kernel = self.kernel_gen(kernel_input)

        # Reshape the kernel.
        # B, H, W, K*K, 1, G
        kernel = self.kernel_reshape(kernel)

        # Extract input patches.
        # B, H, W, K*K*C
        input_patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.kernel_size, self.kernel_size, 1],
            strides=[1, self.stride, self.stride, 1],
            rates=[1, 1, 1, 1],
            padding="SAME",
        )

        # Reshape the input patches to align with later operations.
        # B, H, W, K*K, C//G, G
        input_patches = self.input_patches_reshape(input_patches)

        # Compute the multiply-add operation of kernels and patches.
        # B, H, W, K*K, C//G, G
        output = tf.multiply(kernel, input_patches)
        # B, H, W, C//G, G
        output = tf.reduce_sum(output, axis=3)

        # Reshape the output tensor.
        # B, H, W, C
        output = self.output_reshape(output)

        # Return the output tensor and the kernel.
        return output, kernel


"""
## Testing the Involution layer
"""

# Define the input tensor.
input_tensor = tf.random.normal((32, 256, 256, 3))

# Compute involution with stride 1.
output_tensor, _ = Involution(
    channel=3, group_number=1, kernel_size=5, stride=1, reduction_ratio=1, name="inv_1"
)(input_tensor)
print(f"with stride 1 output shape: {output_tensor.shape}")

# Compute involution with stride 2.
output_tensor, _ = Involution(
    channel=3, group_number=1, kernel_size=5, stride=2, reduction_ratio=1, name="inv_2"
)(input_tensor)
print(f"with stride 2 output shape: {output_tensor.shape}")

# Compute involution with stride 1, channel 16 and reduction ratio 2.
output_tensor, _ = Involution(
    channel=16, group_number=1, kernel_size=5, stride=1, reduction_ratio=2, name="inv_3"
)(input_tensor)
print(
    "with channel 16 and reduction ratio 2 output shape: {}".format(output_tensor.shape)
)

"""
225
## Image Classification
226
227
In this section, we will build an image-classifier model. There will
228
be two models one with convolutions and the other with involutions.
229
230
The image-classification model is heavily inspired by this
231
[Convolutional Neural Network (CNN)](https://www.tensorflow.org/tutorials/images/cnn)
232
tutorial from Google.
233
"""
234
235
"""
236
## Get the CIFAR10 Dataset
237
"""

# Load the CIFAR10 dataset.
print("loading the CIFAR10 dataset...")
(
    (train_images, train_labels),
    (
        test_images,
        test_labels,
    ),
) = keras.datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1.
(train_images, test_images) = (train_images / 255.0, test_images / 255.0)

# Shuffle and batch the dataset.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(256)
    .batch(256)
)
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(256)

"""
261
## Visualise the data
262
"""
263
264
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

"""
288
## Convolutional Neural Network
289
"""
290
291
# Build the conv model.
print("building the convolution model...")
conv_model = keras.Sequential(
    [
        keras.layers.Conv2D(32, (3, 3), input_shape=(32, 32, 3), padding="same"),
        keras.layers.ReLU(name="relu1"),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), padding="same"),
        keras.layers.ReLU(name="relu2"),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), padding="same"),
        keras.layers.ReLU(name="relu3"),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(10),
    ]
)

# Compile the model with the necessary loss function and optimizer.
print("compiling the convolution model...")
conv_model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train the model.
print("conv model training...")
conv_hist = conv_model.fit(train_ds, epochs=20, validation_data=test_ds)

"""
322
## Involutional Neural Network
323
"""
324
325
# Build the involution model.
print("building the involution model...")

inputs = keras.Input(shape=(32, 32, 3))
x, _ = Involution(
    channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_1"
)(inputs)
x = keras.layers.ReLU()(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
x, _ = Involution(
    channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_2"
)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.MaxPooling2D((2, 2))(x)
x, _ = Involution(
    channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_3"
)(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)

inv_model = keras.Model(inputs=[inputs], outputs=[outputs], name="inv_model")

# Compile the model with the necessary loss function and optimizer.
print("compiling the involution model...")
inv_model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train the model.
print("inv model training...")
inv_hist = inv_model.fit(train_ds, epochs=20, validation_data=test_ds)

"""
362
## Comparisons
363
364
In this section, we will be looking at both the models and compare a
365
few pointers.
366
"""
367
368
"""
369
### Parameters
370
371
One can see that with a similar architecture the parameters in a CNN
372
is much larger than that of an INN (Involutional Neural Network).
373
"""
374
375
conv_model.summary()

inv_model.summary()

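"""
The summaries above already show the gap layer by layer; to put a single
number on it (a minimal sketch, not part of the original example), we can
also compare the total parameter counts directly with `Model.count_params()`:
"""

# Compare total parameter counts of the two models.
print("Convolution model parameters:", conv_model.count_params())
print("Involution model parameters:", inv_model.count_params())
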
"""
380
### Loss and Accuracy Plots
381
382
Here, the loss and the accuracy plots demonstrate that INNs are slow
383
learners (with lower parameters).
384
"""
385
386
plt.figure(figsize=(20, 5))

plt.subplot(1, 2, 1)
plt.title("Convolution Loss")
plt.plot(conv_hist.history["loss"], label="loss")
plt.plot(conv_hist.history["val_loss"], label="val_loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.title("Involution Loss")
plt.plot(inv_hist.history["loss"], label="loss")
plt.plot(inv_hist.history["val_loss"], label="val_loss")
plt.legend()

plt.show()

plt.figure(figsize=(20, 5))

plt.subplot(1, 2, 1)
plt.title("Convolution Accuracy")
plt.plot(conv_hist.history["accuracy"], label="accuracy")
plt.plot(conv_hist.history["val_accuracy"], label="val_accuracy")
plt.legend()

plt.subplot(1, 2, 2)
plt.title("Involution Accuracy")
plt.plot(inv_hist.history["accuracy"], label="accuracy")
plt.plot(inv_hist.history["val_accuracy"], label="val_accuracy")
plt.legend()

plt.show()

"""
419
## Visualizing Involution Kernels
420
421
To visualize the kernels, we take the sum of **K×K** values from each
422
involution kernel. **All the representatives at different spatial
423
locations frame the corresponding heat map.**
424
425
The authors mention:
426
427
"Our proposed involution is reminiscent of self-attention and
428
essentially could become a generalized version of it."
429
430
With the visualization of the kernel we can indeed obtain an attention
431
map of the image. The learned involution kernels provides attention to
432
individual spatial positions of the input tensor. The
433
**location-specific** property makes involution a generic space of models
434
in which self-attention belongs.
435
"""
436
437
layer_names = ["inv_1", "inv_2", "inv_3"]
outputs = [inv_model.get_layer(name).output[1] for name in layer_names]
vis_model = keras.Model(inv_model.input, outputs)

fig, axes = plt.subplots(nrows=10, ncols=4, figsize=(10, 30))

for ax, test_image in zip(axes, test_images[:10]):
    (inv1_kernel, inv2_kernel, inv3_kernel) = vis_model.predict(test_image[None, ...])
    inv1_kernel = tf.reduce_sum(inv1_kernel, axis=[-1, -2, -3])
    inv2_kernel = tf.reduce_sum(inv2_kernel, axis=[-1, -2, -3])
    inv3_kernel = tf.reduce_sum(inv3_kernel, axis=[-1, -2, -3])

    ax[0].imshow(keras.utils.array_to_img(test_image))
    ax[0].set_title("Input Image")

    ax[1].imshow(keras.utils.array_to_img(inv1_kernel[0, ..., None]))
    ax[1].set_title("Involution Kernel 1")

    ax[2].imshow(keras.utils.array_to_img(inv2_kernel[0, ..., None]))
    ax[2].set_title("Involution Kernel 2")

    ax[3].imshow(keras.utils.array_to_img(inv3_kernel[0, ..., None]))
    ax[3].set_title("Involution Kernel 3")

"""
462
## Conclusions
463
464
In this example, the main focus was to build an `Involution` layer which
465
can be easily reused. While our comparisons were based on a specific
466
task, feel free to use the layer for different tasks and report your
467
results.
468
469
According to me, the key take-away of involution is its
470
relationship with self-attention. The intuition behind location-specific
471
and channel-spefic processing makes sense in a lot of tasks.
472
473
Moving forward one can:
474
475
- Look at [Yannick's video](https://youtu.be/pH2jZun8MoY) on
476
involution for a better understanding.
477
- Experiment with the various hyperparameters of the involution layer.
478
- Build different models with the involution layer.
479
- Try building a different kernel generation method altogether.
480
481
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/involution)
482
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/involution).
483
"""
484
485