# Source: GitHub repository keras-team/keras-io, path blob/master/examples/keras_recipes/subclassing_conv_layers.py
"""
Title: Customizing the convolution operation of a Conv2D layer
Author: [lukewood](https://lukewood.xyz)
Date created: 11/03/2021
Last modified: 11/03/2021
Description: This example shows how to implement custom convolution layers using the `Conv.convolution_op()` API.
Accelerator: GPU
"""

"""
## Introduction

You may sometimes need to implement custom versions of convolution layers like `Conv1D` and `Conv2D`.
Keras lets you do this without implementing the entire layer from scratch: you can reuse
most of the base convolution layer and customize only the convolution op itself via the
`convolution_op()` method.

This method was introduced in Keras 2.7, so before using the
`convolution_op()` API, make sure that you are running Keras 2.7.0 or later.
"""

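"""
The version requirement above can be checked programmatically. The snippet below is a
minimal sketch, not part of Keras: `meets_minimum` is a hypothetical helper that compares
numeric version components, so that a string such as `"2.10.0"` is not mistaken for
something older than `"2.7.0"`. It assumes the version string starts with dot-separated
integers.

```python
# Hypothetical helper (not part of Keras): compare a version string with a minimum.
def meets_minimum(version, minimum=(2, 7, 0)):
    parts = []
    for piece in version.split("."):
        # Keep only the leading digits, so "0rc1" is read as 0.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad short versions such as "3" before comparing component by component.
    while len(parts) < len(minimum):
        parts.append(0)
    return tuple(parts) >= tuple(minimum)


print(meets_minimum("2.6.0"))  # False
print(meets_minimum("2.10.0"))  # True
print(meets_minimum("3.4.1"))  # True
```

In a real script you would pass `keras.__version__` to such a helper.
"""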
"""
## A Simple `StandardizedConv2D` implementation

There are two ways to use the `Conv.convolution_op()` API. The first way
is to override the `convolution_op()` method on a convolution layer subclass.
Using this approach, we can quickly implement a
[StandardizedConv2D](https://arxiv.org/abs/1903.10520), which normalizes each output
filter of the kernel to zero mean and unit variance before convolving, as shown below.
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np


class StandardizedConv2DWithOverride(layers.Conv2D):
    def convolution_op(self, inputs, kernel):
        # The kernel has shape (height, width, in_channels, filters); reducing over
        # the first three axes yields one mean and variance per output filter.
        mean, var = tf.nn.moments(kernel, axes=[0, 1, 2], keepdims=True)
        return tf.nn.conv2d(
            inputs,
            # Standardize the kernel; the small epsilon avoids division by zero.
            (kernel - mean) / tf.sqrt(var + 1e-10),
            # Padding is hard-coded to "VALID", which matches the `Conv2D` default.
            padding="VALID",
            strides=list(self.strides),
            name=self.__class__.__name__,
        )

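"""
To see what the kernel standardization in `convolution_op()` does numerically, here is a
small NumPy sketch. The kernel shape and random values are made up for illustration; only
the reduction axes and the epsilon mirror the layer above.

```python
import numpy as np

# Toy kernel in the Conv2D layout (height, width, in_channels, filters).
rng = np.random.default_rng(0)
kernel = rng.normal(loc=3.0, scale=2.0, size=(3, 3, 4, 8))

# One mean and variance per output filter, mirroring axes=[0, 1, 2], keepdims=True.
mean = kernel.mean(axis=(0, 1, 2), keepdims=True)
var = kernel.var(axis=(0, 1, 2), keepdims=True)
standardized = (kernel - mean) / np.sqrt(var + 1e-10)

print(mean.shape)  # (1, 1, 1, 8)
print(np.allclose(standardized.mean(axis=(0, 1, 2)), 0.0, atol=1e-6))  # True
print(np.allclose(standardized.var(axis=(0, 1, 2)), 1.0, atol=1e-6))  # True
```

Because the statistics are kept with `keepdims=True`, they broadcast against the kernel,
and each of the 8 filters is normalized independently.
"""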

"""
The other way to use the `Conv.convolution_op()` API is to directly call the
`convolution_op()` method from the `call()` method of a convolution layer subclass.
Because this approach overrides `call()`, the layer must also add the bias and apply
the activation itself. A comparable class implemented using this approach is shown below.
"""


class StandardizedConv2DWithCall(layers.Conv2D):
    def call(self, inputs):
        # Standardize each output filter of the kernel to zero mean and unit variance.
        mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
        result = self.convolution_op(
            inputs, (self.kernel - mean) / tf.sqrt(var + 1e-10)
        )
        if self.use_bias:
            result = result + self.bias
        # Overriding `call()` bypasses the base layer's activation, so apply it here.
        if self.activation is not None:
            result = self.activation(result)
        return result

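"""
The difference between the two approaches can be sketched without Keras. The classes
below are toy stand-ins written for this illustration (`ToyConv1D`, `OverrideOp`,
`OverrideCall`, and `standardize` are hypothetical names, and the base class is not the
Keras implementation). They assume a base `call()` that runs the op, adds a bias, and
applies a ReLU.

```python
import numpy as np


def standardize(kernel):
    # Zero mean and unit variance, with a small epsilon for numerical safety.
    return (kernel - kernel.mean()) / np.sqrt(kernel.var() + 1e-10)


class ToyConv1D:
    # Toy stand-in for a convolution layer (not the Keras implementation).
    def __init__(self, kernel, bias):
        self.kernel = np.asarray(kernel, dtype="float64")
        self.bias = float(bias)

    def convolution_op(self, inputs, kernel):
        # "Valid" 1D cross-correlation of the inputs with the kernel.
        return np.correlate(inputs, kernel, mode="valid")

    def call(self, inputs):
        # The base `call()` runs the op, then adds the bias and applies ReLU.
        outputs = self.convolution_op(inputs, self.kernel) + self.bias
        return np.maximum(outputs, 0.0)


class OverrideOp(ToyConv1D):
    # Approach 1: customize only the op; bias and ReLU come from the base class.
    def convolution_op(self, inputs, kernel):
        return super().convolution_op(inputs, standardize(kernel))


class OverrideCall(ToyConv1D):
    # Approach 2: replace `call()`; bias and ReLU must be re-applied by hand.
    def call(self, inputs):
        outputs = self.convolution_op(inputs, standardize(self.kernel)) + self.bias
        return np.maximum(outputs, 0.0)


x = np.array([1.0, 2.0, 4.0, 7.0, 11.0])
a = OverrideOp(kernel=[1.0, 2.0, 3.0], bias=0.5).call(x)
b = OverrideCall(kernel=[1.0, 2.0, 3.0], bias=0.5).call(x)
print(np.allclose(a, b))  # True
```

Both toy layers produce the same output. Overriding the op leaves less to get wrong,
while overriding `call()` gives control over everything that happens around the op.
"""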

"""
## Example Usage

Both of these layers work as drop-in replacements for `Conv2D`. The following
demonstration performs classification on the MNIST dataset.
"""

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

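"""
The conversion of class vectors to binary class matrices is a one-hot encoding. As a
rough NumPy sketch of the idea (the three labels are made-up values, and this is an
illustration rather than the `to_categorical` implementation):

```python
import numpy as np

labels = np.array([3, 0, 9])  # toy class indices standing in for MNIST labels
num_classes = 10

# Row i of the identity matrix is the one-hot vector for class i.
one_hot = np.eye(num_classes, dtype="float32")[labels]

print(one_hot.shape)  # (3, 10)
print(one_hot[0])  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
print(one_hot.argmax(axis=1))  # [3 0 9]
```

This is the target format that the `categorical_crossentropy` loss expects.
"""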
model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        StandardizedConv2DWithCall(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        StandardizedConv2DWithOverride(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()
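"""
The output shapes in the summary can be worked out by hand. The sketch below assumes
stride-1 "valid" convolutions and non-overlapping 2x2 pooling, as in the model above;
`conv_valid` and `pool` are helper names introduced only for this illustration.

```python
def conv_valid(size, kernel):
    # A stride-1 "valid" convolution shrinks each spatial side by kernel - 1.
    return size - kernel + 1


def pool(size, window):
    # Non-overlapping pooling with "valid" padding floors the division.
    return size // window


side = 28
side = pool(conv_valid(side, 3), 2)  # 28 -> 26 -> 13
side = pool(conv_valid(side, 3), 2)  # 13 -> 11 -> 5
flattened = side * side * 64  # 64 filters in the second convolution
dense_params = flattened * 10 + 10  # weights plus biases of the final Dense layer

print(side, flattened, dense_params)  # 5 1600 16010
```
"""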
"""
Next, compile the model and train it for a few epochs, holding out 10% of the
training data for validation.
"""
batch_size = 128
epochs = 5

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

"""
## Conclusion

The `Conv.convolution_op()` API provides an easy and readable way to implement custom
convolution layers. A `StandardizedConv2D` implementation using the API is quite
terse, consisting of only four lines of code.
"""