"""
Title: Writing Keras Models With TensorFlow NumPy
Author: [lukewood](https://lukewood.xyz)
Date created: 2021/08/28
Last modified: 2021/08/28
Description: Overview of how to use the TensorFlow NumPy API to write Keras models.
Accelerator: GPU
"""

"""
## Introduction

[NumPy](https://numpy.org/) is a hugely successful Python linear algebra library.

TensorFlow recently launched [tf_numpy](https://www.tensorflow.org/guide/tf_numpy), a
TensorFlow implementation of a large subset of the NumPy API.
Thanks to `tf_numpy`, you can write Keras layers or models in the NumPy style!

The TensorFlow NumPy API has full integration with the TensorFlow ecosystem.
Features such as automatic differentiation, TensorBoard, Keras model callbacks,
TPU distribution and model exporting are all supported.

Let's run through a few examples.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import keras
from keras import layers
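
"""
Before diving in, let's sanity-check the integration (an illustrative aside, not part
of the original example): `tnp` functions accept regular TensorFlow tensors, and their
results are differentiable with `tf.GradientTape` like any other TensorFlow operation.
"""

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    # NumPy-style code, executed and traced by TensorFlow:
    y = tnp.sum(tnp.square(x))
# d(sum(x^2))/dx = 2x
print(tape.gradient(y, x))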
"""
40
To test our models we will use the Boston housing prices regression dataset.
41
"""
42
43
(x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data(
44
path="boston_housing.npz", test_split=0.2, seed=113
45
)
46
input_dim = x_train.shape[1]
47
48
49
def evaluate_model(model: keras.Model):
50
loss, percent_error = model.evaluate(x_test, y_test, verbose=0)
51
print("Mean absolute percent error before training: ", percent_error)
52
model.fit(x_train, y_train, epochs=200, verbose=0)
53
loss, percent_error = model.evaluate(x_test, y_test, verbose=0)
54
print("Mean absolute percent error after training:", percent_error)
55
56
57
"""
58
## Subclassing keras.Model with TNP
59
60
The most flexible way to make use of the Keras API is to subclass the
61
[`keras.Model`](https://keras.io/api/models/model/) class. Subclassing the Model class
62
gives you the ability to fully customize what occurs in the training loop. This makes
63
subclassing Model a popular option for researchers.
64
65
In this example, we will implement a `Model` subclass that performs regression over the
66
boston housing dataset using the TNP API. Note that differentiation and gradient
67
descent is handled automatically when using the TNP API alongside keras.
68
69
First let's define a simple `TNPForwardFeedRegressionNetwork` class.
70
"""
71
72
73


class TNPForwardFeedRegressionNetwork(keras.Model):
    def __init__(self, blocks=None, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(blocks, list):
            raise ValueError(f"blocks must be a list, got blocks={blocks}")
        self.blocks = blocks
        self.block_weights = None
        self.biases = None

    def build(self, input_shape):
        current_shape = input_shape[1]
        self.block_weights = []
        self.biases = []
        for i, block in enumerate(self.blocks):
            self.block_weights.append(
                self.add_weight(
                    shape=(current_shape, block),
                    trainable=True,
                    name=f"block-{i}",
                    initializer="glorot_normal",
                )
            )
            self.biases.append(
                self.add_weight(
                    shape=(block,),
                    trainable=True,
                    name=f"bias-{i}",
                    initializer="zeros",
                )
            )
            current_shape = block

        self.linear_layer = self.add_weight(
            shape=(current_shape, 1),
            name="linear_projector",
            trainable=True,
            initializer="glorot_normal",
        )

    def call(self, inputs):
        activations = inputs
        for w, b in zip(self.block_weights, self.biases):
            activations = tnp.matmul(activations, w) + b
            # ReLU activation function
            activations = tnp.maximum(activations, 0.0)

        return tnp.matmul(activations, self.linear_layer)
"""
123
Just like with any other Keras model we can utilize any supported optimizer, loss,
124
metrics or callbacks that we want.
125
126
Let's see how the model performs!
127
"""
128
129
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
130
model.compile(
131
optimizer="adam",
132
loss="mean_squared_error",
133
metrics=[keras.metrics.MeanAbsolutePercentageError()],
134
)
135
evaluate_model(model)
136
137
"""
138
Great! Our model seems to be effectively learning to solve the problem at hand.
139
140
We can also write our own custom loss function using TNP.
141
"""
142
143
144
def tnp_mse(y_true, y_pred):
145
return tnp.mean(tnp.square(y_true - y_pred), axis=0)
146
147
148
keras.backend.clear_session()
149
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
150
model.compile(
151
optimizer="adam",
152
loss=tnp_mse,
153
metrics=[keras.metrics.MeanAbsolutePercentageError()],
154
)
155
evaluate_model(model)
156
157
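
"""
The same pattern extends to custom metrics (an illustrative sketch, not part of the
original example): Keras accepts a plain function of `(y_true, y_pred)` as a metric,
so we can write one with TNP as well and pass it via `metrics=[tnp_mae]` in `compile`.
"""


def tnp_mae(y_true, y_pred):
    # Mean absolute error, written in the NumPy style.
    return tnp.mean(tnp.abs(y_true - y_pred), axis=0)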
"""
158
## Implementing a Keras Layer Based Model with TNP
159
160
If desired, TNP can also be used in layer oriented Keras code structure. Let's
161
implement the same model, but using a layered approach!
162
"""
163
164
165
def tnp_relu(x):
166
return tnp.maximum(x, 0)
167
168
169


class TNPDense(keras.layers.Layer):
    def __init__(self, units, activation=None):
        super().__init__()
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        self.w = self.add_weight(
            name="weights",
            shape=(input_shape[1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        outputs = tnp.matmul(inputs, self.w) + self.bias
        if self.activation:
            return self.activation(outputs)
        return outputs


def create_layered_tnp_model():
    return keras.Sequential(
        [
            TNPDense(3, activation=tnp_relu),
            TNPDense(3, activation=tnp_relu),
            TNPDense(1),
        ]
    )


model = create_layered_tnp_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, input_dim))
model.summary()

evaluate_model(model)
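
"""
Because this is a regular Keras model, the usual export APIs apply too (an illustrative
sketch backing the "model exporting" claim from the introduction; the filename is
arbitrary): weights can be saved and restored with the standard Keras methods.
"""

model.save_weights("layered_tnp_model.weights.h5")
model.load_weights("layered_tnp_model.weights.h5")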
"""
218
You can also seamlessly switch between TNP layers and native Keras layers!
219
"""
220
221
222
def create_mixed_model():
223
return keras.Sequential(
224
[
225
TNPDense(3, activation=tnp_relu),
226
# The model will have no issue using a normal Dense layer
227
layers.Dense(3, activation="relu"),
228
# ... or switching back to tnp layers!
229
TNPDense(1),
230
]
231
)
232
233
234
model = create_mixed_model()
235
model.compile(
236
optimizer="adam",
237
loss="mean_squared_error",
238
metrics=[keras.metrics.MeanAbsolutePercentageError()],
239
)
240
model.build((None, input_dim))
241
model.summary()
242
243
evaluate_model(model)
244
245
"""
246
The Keras API offers a wide variety of layers. The ability to use them alongside NumPy
247
code can be a huge time saver in projects.
248
"""
249
250
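
"""
This mixing also works with the Keras functional API (an illustrative sketch, not part
of the original example): custom TNP layers compose with built-in layers just like any
other Keras layer.
"""

inputs = keras.Input(shape=(input_dim,))
x = TNPDense(3, activation=tnp_relu)(inputs)
outputs = layers.Dense(1)(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs)
functional_model.summary()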
"""
251
## Distribution Strategy
252
253
TensorFlow NumPy and Keras integrate with
254
[TensorFlow Distribution Strategies](https://www.tensorflow.org/guide/distributed_training).
255
This makes it simple to perform distributed training across multiple GPUs,
256
or even an entire TPU Pod.
257
"""
258
259
gpus = tf.config.list_logical_devices("GPU")
260
if gpus:
261
strategy = tf.distribute.MirroredStrategy(gpus)
262
else:
263
# We can fallback to a no-op CPU strategy.
264
strategy = tf.distribute.get_strategy()
265
print("Running with strategy:", str(strategy.__class__.__name__))
266
267

with strategy.scope():
    model = create_layered_tnp_model()
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.build((None, input_dim))
    model.summary()
    evaluate_model(model)
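
"""
On a TPU runtime, the same code works under a TPU strategy (an untested, illustrative
sketch; it requires a TPU-enabled environment, so it is shown here rather than
executed):

```
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
```
"""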
"""
279
## TensorBoard Integration
280
281
One of the many benefits of using the Keras API is the ability to monitor training
282
through TensorBoard. Using the TensorFlow NumPy API alongside Keras allows you to easily
283
leverage TensorBoard.
284
"""
285
286
keras.backend.clear_session()
287
288
"""
289
To load the TensorBoard from a Jupyter notebook, you can run the following magic:
290
```
291
%load_ext tensorboard
292
```
293
294
"""
295
296

models = [
    (
        TNPForwardFeedRegressionNetwork(blocks=[3, 3]),
        "TNPForwardFeedRegressionNetwork",
    ),
    (create_layered_tnp_model(), "layered_tnp_model"),
    (create_mixed_model(), "mixed_model"),
]
for model, model_name in models:
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.fit(
        x_train,
        y_train,
        epochs=200,
        verbose=0,
        callbacks=[keras.callbacks.TensorBoard(log_dir=f"logs/{model_name}")],
    )
"""
319
To load the TensorBoard from a Jupyter notebook you can use the `%tensorboard` magic:
320
321
```
322
%tensorboard --logdir logs
323
```
324
325
The TensorBoard monitor metrics and examine the training curve.
326
327
![Tensorboard training graph](https://i.imgur.com/wsOuFnz.png)
328
329
The TensorBoard also allows you to explore the computation graph used in your models.
330
331
![Tensorboard graph exploration](https://i.imgur.com/tOrezDL.png)
332
333
The ability to introspect into your models can be valuable during debugging.
334
"""
335
336
"""
337
## Conclusion
338
339
Porting existing NumPy code to Keras models using the `tensorflow_numpy` API is easy!
340
By integrating with Keras you gain the ability to use existing Keras callbacks, metrics
341
and optimizers, easily distribute your training and use Tensorboard.
342
343
Migrating a more complex model, such as a ResNet, to the TensorFlow NumPy API would be a
344
great follow up learning exercise.
345
346
Several open source NumPy ResNet implementations are available online.
347
"""
348
349