"""
Title: Trainer pattern
Author: [nkovela1](https://nkovela1.github.io/)
Date created: 2022/09/19
Last modified: 2022/09/26
Description: Guide on how to share a custom training step across multiple Keras models.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to create a custom training step using the "Trainer pattern",
which can then be shared across multiple Keras models. This pattern overrides the
`train_step()` method of the `keras.Model` class, allowing for training loops
beyond plain supervised learning.

The Trainer pattern can also easily be adapted to more complex models with larger
custom training steps, such as
[this end-to-end GAN model](https://keras.io/guides/custom_train_step_in_tensorflow/#wrapping-up-an-endtoend-gan-example),
by putting the custom training step in the Trainer class definition.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

# Load the MNIST dataset and normalize pixel values to the [0, 1] range.
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
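
"""
As a quick sanity check (an optional addition, not part of the original script),
the arrays returned by `load_data()` should be 60,000 training and 10,000 test
images of 28x28 pixels:
"""

print(x_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)  # (10000, 28, 28) (10000,)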

"""
## Define the Trainer class

A custom training and evaluation step can be created by overriding
the `train_step()` and `test_step()` methods of a `Model` subclass:
"""

class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data

        # Inference step
        y_pred = self.model(x, training=False)

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x
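
"""
Note that `train_step()` above returns only the accuracy metric, so `fit()` will not
log the loss value. If you also want the running loss reported, one option (a minimal
sketch, not part of the original example) is to track it with a `keras.metrics.Mean`
and update each metric explicitly rather than in a loop:
"""


class MyTrainerWithLossTracking(MyTrainer):  # Hypothetical extension for illustration.
    def __init__(self, model):
        super().__init__(model)
        # Tracks the average loss over the current epoch.
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing both metrics lets Keras reset them at each epoch boundary.
        return [self.loss_tracker, self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            loss = self.loss_fn(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update each metric with the value it expects: the loss tracker takes
        # the scalar loss, the accuracy metric takes labels and predictions.
        self.loss_tracker.update_state(loss)
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)
        self.loss_tracker.update_state(self.loss_fn(y, y_pred))
        self.accuracy_metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
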
"""
101
## Define multiple models to share the custom training step
102
103
Let's define two different models that can share our Trainer class and its custom `train_step()`:
104
"""
105
106
# A model defined using Sequential API
107
model_a = keras.models.Sequential(
108
[
109
keras.layers.Flatten(),
110
keras.layers.Dense(256, activation="relu"),
111
keras.layers.Dropout(0.2),
112
keras.layers.Dense(10, activation="softmax"),
113
]
114
)
115
116
# A model defined using Functional API
117
func_input = keras.Input(shape=(28, 28, 1))
118
x = keras.layers.Flatten()(func_input)
119
x = keras.layers.Dense(512, activation="relu")(x)
120
x = keras.layers.Dropout(0.4)(x)
121
func_output = keras.layers.Dense(10, activation="softmax")(x)
122
123
model_b = keras.Model(func_input, func_output)
124
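
"""
`MyTrainer` is agnostic to how the wrapped model was built, so a subclassed model
works just as well. A minimal sketch of a hypothetical third variant (`model_c` is
not used in the rest of this example):
"""


class SubclassedModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(128, activation="relu")
        self.out = keras.layers.Dense(10, activation="softmax")

    def call(self, x):
        return self.out(self.dense(self.flatten(x)))


# model_c could be wrapped exactly like model_a and model_b below.
model_c = SubclassedModel()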

"""
## Create Trainer class objects from the models
"""

trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)
"""
133
## Compile and fit the models to the MNIST dataset
134
"""
135
136
trainer_1.compile(optimizer=keras.optimizers.SGD())
137
trainer_1.fit(
138
x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
139
)
140
141
trainer_2.compile(optimizer=keras.optimizers.Adam())
142
trainer_2.fit(
143
x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
144
)
145
146
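
"""
Because `MyTrainer` also overrides `test_step()`, the trained wrappers can be
evaluated directly. A short usage sketch (not part of the original script):
"""

# `evaluate()` runs the custom `test_step()` defined above and reports
# the shared accuracy metric on the held-out test set.
trainer_1.evaluate(x_test, y_test, batch_size=64)
trainer_2.evaluate(x_test, y_test, batch_size=64)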