"""1Title: Customizing the convolution operation of a Conv2D layer2Author: [lukewood](https://lukewood.xyz)3Date created: 11/03/20214Last modified: 11/03/20215Description: This example shows how to implement custom convolution layers using the `Conv.convolution_op()` API.6Accelerator: GPU7"""89"""10## Introduction1112You may sometimes need to implement custom versions of convolution layers like `Conv1D` and `Conv2D`.13Keras enables you do this without implementing the entire layer from scratch: you can reuse14most of the base convolution layer and just customize the convolution op itself via the15`convolution_op()` method.1617This method was introduced in Keras 2.7. So before using the18`convolution_op()` API, ensure that you are running Keras version 2.7.0 or greater.19"""2021"""22## A Simple `StandardizedConv2D` implementation2324There are two ways to use the `Conv.convolution_op()` API. The first way25is to override the `convolution_op()` method on a convolution layer subclass.26Using this approach, we can quickly implement a27[StandardizedConv2D](https://arxiv.org/abs/1903.10520) as shown below.28"""29import os3031os.environ["KERAS_BACKEND"] = "tensorflow"3233import tensorflow as tf34import keras35from keras import layers36import numpy as np373839class StandardizedConv2DWithOverride(layers.Conv2D):40def convolution_op(self, inputs, kernel):41mean, var = tf.nn.moments(kernel, axes=[0, 1, 2], keepdims=True)42return tf.nn.conv2d(43inputs,44(kernel - mean) / tf.sqrt(var + 1e-10),45padding="VALID",46strides=list(self.strides),47name=self.__class__.__name__,48)495051"""52The other way to use the `Conv.convolution_op()` API is to directly call the53`convolution_op()` method from the `call()` method of a convolution layer subclass.54A comparable class implemented using this approach is shown below.55"""565758class StandardizedConv2DWithCall(layers.Conv2D):59def call(self, inputs):60mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)61result = 
self.convolution_op(62inputs, (self.kernel - mean) / tf.sqrt(var + 1e-10)63)64if self.use_bias:65result = result + self.bias66return result676869"""70## Example Usage7172Both of these layers work as drop-in replacements for `Conv2D`. The following73demonstration performs classification on the MNIST dataset.74"""7576# Model / data parameters77num_classes = 1078input_shape = (28, 28, 1)7980# the data, split between train and test sets81(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()8283# Scale images to the [0, 1] range84x_train = x_train.astype("float32") / 25585x_test = x_test.astype("float32") / 25586# Make sure images have shape (28, 28, 1)87x_train = np.expand_dims(x_train, -1)88x_test = np.expand_dims(x_test, -1)89print("x_train shape:", x_train.shape)90print(x_train.shape[0], "train samples")91print(x_test.shape[0], "test samples")9293# convert class vectors to binary class matrices94y_train = keras.utils.to_categorical(y_train, num_classes)95y_test = keras.utils.to_categorical(y_test, num_classes)9697model = keras.Sequential(98[99keras.layers.Input(shape=input_shape),100StandardizedConv2DWithCall(32, kernel_size=(3, 3), activation="relu"),101layers.MaxPooling2D(pool_size=(2, 2)),102StandardizedConv2DWithOverride(64, kernel_size=(3, 3), activation="relu"),103layers.MaxPooling2D(pool_size=(2, 2)),104layers.Flatten(),105layers.Dropout(0.5),106layers.Dense(num_classes, activation="softmax"),107]108)109110model.summary()111"""112113"""114batch_size = 128115epochs = 5116117model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])118119model.fit(x_train, y_train, batch_size=batch_size, epochs=5, validation_split=0.1)120121"""122## Conclusion123124The `Conv.convolution_op()` API provides an easy and readable way to implement custom125convolution layers. A `StandardizedConvolution` implementation using the API is quite126terse, consisting of only four lines of code.127"""128129130
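"""
As a closing sanity check, the kernel standardization shared by both layers can be
reproduced outside of Keras. The sketch below is a NumPy-only illustration, not part
of the Keras API: the helper name `standardize_kernel`, the random kernel, and its
shape `(3, 3, 16, 32)` are assumptions made for this example. It reduces over the
same axes `[0, 1, 2]` (height, width, input channels) and uses the same `1e-10`
epsilon, so each output filter should end up with roughly zero mean and unit variance.
"""

```python
import numpy as np


def standardize_kernel(kernel, eps=1e-10):
    """Standardize a (height, width, in_channels, out_channels) kernel per output filter.

    Hypothetical helper mirroring the expression used in the layers above:
    (kernel - mean) / sqrt(var + eps), with statistics over axes 0, 1, and 2.
    """
    mean = kernel.mean(axis=(0, 1, 2), keepdims=True)
    # np.var defaults to the population variance, like tf.nn.moments.
    var = kernel.var(axis=(0, 1, 2), keepdims=True)
    return (kernel - mean) / np.sqrt(var + eps)


# A made-up kernel with a deliberately non-zero mean and non-unit scale.
rng = np.random.default_rng(0)
kernel = rng.normal(loc=0.5, scale=2.0, size=(3, 3, 16, 32))

standardized = standardize_kernel(kernel)

# One statistic per output filter (the last axis is preserved).
per_filter_mean = standardized.mean(axis=(0, 1, 2))
per_filter_var = standardized.var(axis=(0, 1, 2))

print("kernel shape:", standardized.shape)
print("means close to 0:", np.allclose(per_filter_mean, 0.0, atol=1e-6))
print("variances close to 1:", np.allclose(per_filter_var, 1.0, atol=1e-6))
```

"""
Because the statistics are computed with `keepdims=True`, they have shape
`(1, 1, 1, out_channels)` and broadcast against the kernel without any reshaping.
"""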