Path: blob/master/examples/keras_recipes/trainer_pattern.py
"""
Title: Trainer pattern
Author: [nkovela1](https://nkovela1.github.io/)
Date created: 2022/09/19
Last modified: 2022/09/26
Description: Guide on how to share a custom training step across multiple Keras models.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to create a custom training step using the "Trainer pattern",
which can then be shared across multiple Keras models. This pattern overrides the
`train_step()` method of the `keras.Model` class, allowing for training loops
beyond plain supervised learning.

The Trainer pattern can also easily be adapted to more complex models with larger
custom training steps, such as
[this end-to-end GAN model](https://keras.io/guides/custom_train_step_in_tensorflow/#wrapping-up-an-endtoend-gan-example),
by putting the custom training step in the Trainer class definition.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

# Load the MNIST dataset and scale pixel values to the [0, 1] range
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


"""
## Define the Trainer class

A custom training and evaluation step can be created by overriding
the `train_step()` and `test_step()` methods of a `Model` subclass:
"""


class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data

        # Inference step
        y_pred = self.model(x, training=False)

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x


"""
## Define multiple models to share the custom training step

Let's define two different models that can share our Trainer class and its custom `train_step()`:
"""

# A model defined using Sequential API
model_a = keras.models.Sequential(
    [
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

# A model defined using Functional API
func_input = keras.Input(shape=(28, 28, 1))
x = keras.layers.Flatten()(func_input)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dropout(0.4)(x)
func_output = keras.layers.Dense(10, activation="softmax")(x)

model_b = keras.Model(func_input, func_output)

"""
## Create Trainer class objects from the models
"""

trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)

"""
## Compile and fit the models to the MNIST dataset
"""

trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)

trainer_2.compile(optimizer=keras.optimizers.Adam())
trainer_2.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)
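
"""
## Optional: report the loss alongside accuracy

The `train_step()` and `test_step()` defined in this example return only the accuracy
metric, so the training logs contain no loss value. The snippet below is a minimal
sketch of one way to add it, by tracking the loss with a `keras.metrics.Mean`. The names
`LossTrackingTrainer`, `small`, and `wide`, as well as the random toy data, are
assumptions made for illustration and are not part of the original example. The snippet
is shown for reference and is not executed by this script.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import keras


class LossTrackingTrainer(keras.Model):
    # Hypothetical variant of the trainer that also reports the mean loss.
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy(
            name="accuracy"
        )

    @property
    def metrics(self):
        # Metrics listed here are reset by Keras at the start of each epoch.
        return [self.loss_tracker, self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            loss = self.loss_fn(y, y_pred)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # The loss tracker averages the per-batch loss over the epoch.
        self.loss_tracker.update_state(loss)
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)
        self.loss_tracker.update_state(self.loss_fn(y, y_pred))
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        return self.model(x)


# Toy data (an assumption for this sketch): 256 samples, 8 features, 3 classes.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8)).astype("float32")
y = rng.integers(0, 3, size=(256,))

small = keras.Sequential(
    [keras.Input(shape=(8,)), keras.layers.Dense(3, activation="softmax")]
)
wide = keras.Sequential(
    [
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ]
)

# The same trainer class drives both models.
histories = {}
for name, net in {"small": small, "wide": wide}.items():
    trainer = LossTrackingTrainer(net)
    trainer.compile(optimizer=keras.optimizers.Adam())
    history = trainer.fit(
        x, y, epochs=1, batch_size=32, validation_data=(x, y), verbose=0
    )
    histories[name] = history.history

# The trainer only wraps the model, so the trained weights live in `small`.
probs = small.predict(x[:4], verbose=0)
```

In this sketch the history of each run holds `loss` and `accuracy` entries plus their
`val_` counterparts, and the wrapped models can be used directly for inference once
training finishes.
"""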