Path: examples/keras_recipes/antirectifier.py
"""1Title: Simple custom layer example: Antirectifier2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2016/01/064Last modified: 2023/11/205Description: Demonstration of custom layer creation.6Accelerator: GPU7"""89"""10## Introduction1112This example shows how to create custom layers, using the Antirectifier layer13(originally proposed as a Keras example script in January 2016), an alternative14to ReLU. Instead of zeroing-out the negative part of the input, it splits the negative15and positive parts and returns the concatenation of the absolute value16of both. This avoids loss of information, at the cost of an increase in dimensionality.17To fix the dimensionality increase, we linearly combine the18features back to a space of the original size.19"""2021"""22## Setup23"""2425import keras26from keras import layers27from keras import ops2829"""30## The Antirectifier layer3132To implement a custom layer:3334- Create the state variables via `add_weight()` in `__init__` or `build()`.35Similarly, you can also create sublayers.36- Implement the `call()` method, taking the layer's input tensor(s) and37return the output tensor(s).38- Optionally, you can also enable serialization by implementing `get_config()`,39which returns a configuration dictionary.4041See also the guide42[Making new layers and models via subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/).43"""444546class Antirectifier(layers.Layer):47def __init__(self, initializer="he_normal", **kwargs):48super().__init__(**kwargs)49self.initializer = keras.initializers.get(initializer)5051def build(self, input_shape):52output_dim = input_shape[-1]53self.kernel = self.add_weight(54shape=(output_dim * 2, output_dim),55initializer=self.initializer,56name="kernel",57trainable=True,58)5960def call(self, inputs):61inputs -= ops.mean(inputs, axis=-1, keepdims=True)62pos = ops.relu(inputs)63neg = ops.relu(-inputs)64concatenated = ops.concatenate([pos, neg], axis=-1)65mixed = ops.matmul(concatenated, self.kernel)66return mixed6768def get_config(self):69# Implement get_config to enable serialization. This is optional.70base_config = super().get_config()71config = {"initializer": keras.initializers.serialize(self.initializer)}72return dict(list(base_config.items()) + list(config.items()))737475"""76## Let's test-drive it on MNIST77"""7879# Training parameters80batch_size = 12881num_classes = 1082epochs = 208384# The data, split between train and test sets85(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()8687x_train = x_train.reshape(-1, 784)88x_test = x_test.reshape(-1, 784)89x_train = x_train.astype("float32")90x_test = x_test.astype("float32")91x_train /= 25592x_test /= 25593print(x_train.shape[0], "train samples")94print(x_test.shape[0], "test samples")9596# Build the model97model = keras.Sequential(98[99keras.Input(shape=(784,)),100layers.Dense(256),101Antirectifier(),102layers.Dense(256),103Antirectifier(),104layers.Dropout(0.5),105layers.Dense(10),106]107)108109# Compile the model110model.compile(111loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),112optimizer=keras.optimizers.RMSprop(),113metrics=[keras.metrics.SparseCategoricalAccuracy()],114)115116# Train the model117model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)118119# Test the model120model.evaluate(x_test, y_test)121122123