"""
Title: Keras debugging tips
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/05/16
Last modified: 2023/11/16
Description: Three simple tips to help you debug your Keras code.
Accelerator: GPU
"""

"""
## Introduction

It's generally possible to do almost anything in Keras *without writing code* per se:
whether you're implementing a new type of GAN or the latest convnet architecture for
image segmentation, you can usually stick to calling built-in methods. Because all
built-in methods do extensive input validation checks, you will have little to no
debugging to do. A Functional API model made entirely of built-in layers will work on
the first try -- if you can compile it, it will run.

However, sometimes you will need to dive deeper and write your own code. Here are some
common examples:

- Creating a new `Layer` subclass.
- Creating a custom `Metric` subclass.
- Implementing a custom `train_step` on a `Model`.

This document provides a few simple tips to help you navigate debugging in these
situations.
"""

"""
## Tip 1: test each part before you test the whole

If you've created any object that has a chance of not working as expected, don't just
drop it in your end-to-end process and watch sparks fly. Rather, test your custom object
in isolation first. This may seem obvious -- but you'd be surprised how often people
don't start with this.

- If you write a custom layer, don't call `fit()` on your entire model just yet. Call
your layer on some test data first.
- If you write a custom metric, start by printing its output for some reference inputs,
as in the sketch right after this list.
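
For instance, here's a minimal sketch of that kind of check. The `PositiveRate` metric
below is a toy metric written just for this illustration (it isn't used elsewhere in this
guide): we feed it reference inputs whose expected result we can compute by hand.

```python
import numpy as np
import keras
from keras import ops


class PositiveRate(keras.metrics.Metric):
    # Tracks the fraction of predictions that are greater than 0.5.
    def __init__(self, name="positive_rate", **kwargs):
        super().__init__(name=name, **kwargs)
        self.positives = self.add_weight(
            shape=(), initializer="zeros", name="positives"
        )
        self.total = self.add_weight(shape=(), initializer="zeros", name="total")

    def update_state(self, y_true, y_pred, sample_weight=None):
        positives = ops.sum(ops.cast(ops.greater(y_pred, 0.5), "float32"))
        self.positives.assign_add(positives)
        self.total.assign_add(ops.cast(ops.size(y_pred), "float32"))

    def result(self):
        return self.positives / self.total


metric = PositiveRate()
metric.update_state(y_true=np.zeros((4,)), y_pred=np.array([0.1, 0.6, 0.9, 0.4]))
print(metric.result())  # Expected value for these reference inputs: 0.5
```

If the printed value doesn't match what you computed by hand, you've found your bug
before ever calling `fit()`.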

Here's a simple example. Let's write a custom layer with a bug in it:
"""

import os

# The last example uses tf.GradientTape and thus requires TensorFlow.
# However, all tips here are applicable with all backends.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
import numpy as np
import tensorflow as tf


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        # Take the positive part of the input
        pos = ops.relu(inputs)
        # Take the negative part of the input
        neg = ops.relu(-inputs)
        # Concatenate the positive and negative parts
        concatenated = ops.concatenate([pos, neg], axis=0)
        # Project the concatenation down to the same dimensionality as the input
        return ops.matmul(concatenated, self.kernel)
"""
83
Now, rather than using it in a end-to-end model directly, let's try to call the layer on
84
some test data:
85
86
```python
87
x = tf.random.normal(shape=(2, 5))
88
y = MyAntirectifier()(x)
89
```
90
91
We get the following error:
92
93
```
94
...
95
1 x = tf.random.normal(shape=(2, 5))
96
----> 2 y = MyAntirectifier()(x)
97
...
98
17 neg = tf.nn.relu(-inputs)
99
18 concatenated = tf.concat([pos, neg], axis=0)
100
---> 19 return tf.matmul(concatenated, self.kernel)
101
...
102
InvalidArgumentError: Matrix size-incompatible: In[0]: [4,5], In[1]: [10,5] [Op:MatMul]
103
```
104
105
Looks like our input tensor in the `matmul` op may have an incorrect shape.
106
Let's add a print statement to check the actual shapes:
107
108
"""
109
110


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=0)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)
"""
133
We get the following:
134
135
```
136
pos.shape: (2, 5)
137
neg.shape: (2, 5)
138
concatenated.shape: (4, 5)
139
kernel.shape: (10, 5)
140
```
141
142
Turns out we had the wrong axis for the `concat` op! We should be concatenating `neg` and
143
`pos` alongside the feature axis 1, not the batch axis 0. Here's the correct version:
144
"""
145
146
147


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=1)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)
"""
169
Now our code works fine:
170
"""
171
172
x = keras.random.normal(shape=(2, 5))
173
y = MyAntirectifier()(x)
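
"""
To make the check self-verifying, you can also assert the expected output shape, so the
test fails loudly if the layer ever regresses:
"""

assert y.shape == (2, 5)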

"""
## Tip 2: use `model.summary()` and `plot_model()` to check layer output shapes

If you're working with complex network topologies, you're going to need a way
to visualize how your layers are connected and how they transform the data that passes
through them.

Here's an example. Consider this model with three inputs and two outputs (lifted from the
[Functional API guide](https://keras.io/guides/functional_api/#manipulate-complex-graph-topologies)):
"""

num_tags = 12  # Number of unique issue tags
num_words = 10000  # Size of vocabulary obtained when preprocessing text data
num_departments = 4  # Number of departments for predictions

title_input = keras.Input(
    shape=(None,), name="title"
)  # Variable-length sequence of ints
body_input = keras.Input(shape=(None,), name="body")  # Variable-length sequence of ints
tags_input = keras.Input(
    shape=(num_tags,), name="tags"
)  # Binary vectors of size `num_tags`

# Embed each word in the title into a 64-dimensional vector
title_features = layers.Embedding(num_words, 64)(title_input)
# Embed each word in the body into a 64-dimensional vector
body_features = layers.Embedding(num_words, 64)(body_input)

# Reduce sequence of embedded words in the title into a single 128-dimensional vector
title_features = layers.LSTM(128)(title_features)
# Reduce sequence of embedded words in the body into a single 32-dimensional vector
body_features = layers.LSTM(32)(body_features)

# Merge all available features into a single large vector via concatenation
x = layers.concatenate([title_features, body_features, tags_input])

# Stick a logistic regression for priority prediction on top of the features
priority_pred = layers.Dense(1, name="priority")(x)
# Stick a department classifier on top of the features
department_pred = layers.Dense(num_departments, name="department")(x)

# Instantiate an end-to-end model predicting both priority and department
model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority_pred, department_pred],
)
"""
224
Calling `summary()` can help you check the output shape of each layer:
225
"""
226
227
model.summary()
228

"""
You can also visualize the entire network topology alongside output shapes using
`plot_model`:
"""

keras.utils.plot_model(model, show_shapes=True)

"""
With this plot, any connectivity-level error becomes immediately obvious.
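
If you want to keep the diagram around or tune how it's rendered, `plot_model` accepts a
few extra arguments. The call below is only a sketch of commonly useful options (like the
call above, it assumes `pydot` and `graphviz` are installed):

```python
keras.utils.plot_model(
    model,
    to_file="model_diagram.png",  # Write the diagram to disk
    show_shapes=True,  # Annotate each layer with its output shape
    show_layer_names=True,  # Include layer names in the boxes
    expand_nested=True,  # Expand nested models into their own clusters
    rankdir="LR",  # Lay the graph out left-to-right instead of top-to-bottom
)
```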
"""
"""
241
## Tip 3: to debug what happens during `fit()`, use `run_eagerly=True`
242
243
The `fit()` method is fast: it runs a well-optimized, fully-compiled computation graph.
244
That's great for performance, but it also means that the code you're executing isn't the
245
Python code you've written. This can be problematic when debugging. As you may recall,
246
Python is slow -- so we use it as a staging language, not as an execution language.
247
248
Thankfully, there's an easy way to run your code in "debug mode", fully eagerly:
249
pass `run_eagerly=True` to `compile()`. Your call to `fit()` will now get executed line
250
by line, without any optimization. It's slower, but it makes it possible to print the
251
value of intermediate tensors, or to use a Python debugger. Great for debugging.
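
The switch itself is just a `compile()` flag. The snippet below is a sketch of the
pattern only -- the model, data, loss, and optimizer shown here are placeholders, not
objects defined in this example:

```python
# Debug mode: train_step runs eagerly, so prints and breakpoints work.
model.compile(optimizer="adam", loss="mse", run_eagerly=True)
model.fit(x, y, epochs=1)

# Once the code behaves as expected, recompile without the flag to get
# back the fast, compiled execution path.
model.compile(optimizer="adam", loss="mse")
```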

Here's a basic example: let's write a really simple model with a custom `train_step()`
method. Our model implements plain gradient descent, but instead of using first-order
gradients alone, it uses a combination of first-order and second-order gradients. Pretty
simple so far.

Can you spot what we're doing wrong?
"""


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
"""
294
Let's train a one-layer model on MNIST with this custom loss function.
295
296
We pick, somewhat at random, a batch size of 1024 and a learning rate of 0.1. The general
297
idea being to use larger batches and a larger learning rate than usual, since our
298
"improved" gradients should lead us to quicker convergence.
299
"""
300
301


# Construct an instance of MyModel
def get_model():
    inputs = keras.Input(shape=(784,))
    intermediate = layers.Dense(256, activation="relu")(inputs)
    outputs = layers.Dense(10, activation="softmax")(intermediate)
    model = MyModel(inputs, outputs)
    return model


# Prepare data
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)) / 255

model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
)
model.fit(x_train, y_train, epochs=3, batch_size=1024, validation_split=0.1)

"""
Oh no, it doesn't converge! Something is not working as planned.

Time for some step-by-step printing of what's going on with our gradients.

We add various `print` statements in the `train_step` method, and we make sure to pass
`run_eagerly=True` to `compile()` to run our code step by step, eagerly.
"""


class MyModel(keras.Model):
    def train_step(self, data):
        print()
        print("----Start of step: %d" % (self.step_counter,))
        self.step_counter += 1

        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        print("Max of dl_dw[0]: %.4f" % tf.reduce_max(dl_dw[0]))
        print("Min of dl_dw[0]: %.4f" % tf.reduce_min(dl_dw[0]))
        print("Mean of dl_dw[0]: %.4f" % tf.reduce_mean(dl_dw[0]))
        print("-")
        print("Max of d2l_dw2[0]: %.4f" % tf.reduce_max(d2l_dw2[0]))
        print("Min of d2l_dw2[0]: %.4f" % tf.reduce_min(d2l_dw2[0]))
        print("Mean of d2l_dw2[0]: %.4f" % tf.reduce_mean(d2l_dw2[0]))

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
    run_eagerly=True,
)
model.step_counter = 0
# We pass epochs=1 and steps_per_epoch=10 to only run 10 steps of training.
model.fit(x_train, y_train, epochs=1, batch_size=1024, verbose=0, steps_per_epoch=10)

"""
What did we learn?

- The first-order and second-order gradients can have values that differ by orders of
magnitude.
- Sometimes, they may not even have the same sign.
- Their values can vary greatly from step to step.

This leads us to an obvious idea: let's normalize the gradients before combining them.
"""


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        # Normalize each gradient tensor to unit L2 norm before combining
        dl_dw = [tf.math.l2_normalize(w) for w in dl_dw]
        d2l_dw2 = [tf.math.l2_normalize(w) for w in d2l_dw2]

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
model.fit(x_train, y_train, epochs=5, batch_size=1024, validation_split=0.1)

"""
Now, training converges! It doesn't work well at all, but at least the model learns
something.

After spending a few minutes tuning parameters, we get to the following configuration
that works somewhat well (achieves 97% validation accuracy and seems reasonably robust to
overfitting):

- Use `0.2 * w1 + 0.8 * w2` for combining gradients.
- Use a learning rate schedule that decays over time (the `InverseTimeDecay` schedule
used below).

I'm not going to say that the idea works -- this isn't at all how you're supposed to do
second-order optimization (pointers: see the Newton & Gauss-Newton methods, quasi-Newton
methods, and BFGS). But hopefully this demonstration gave you an idea of how you can
debug your way out of uncomfortable training situations.
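
For reference, a true Newton step preconditions the gradient with the inverse Hessian
instead of averaging gradient tensors of different orders:

$$w \leftarrow w - H^{-1} \nabla_w L(w)$$

where $L$ is the loss and $H$ its Hessian with respect to the weights $w$ -- in practice
approximated or damped, since computing $H^{-1}$ exactly is intractable for deep networks.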

Remember: use `run_eagerly=True` for debugging what happens in `fit()`. And when your code
is finally working as expected, make sure to remove this flag in order to get the best
runtime performance!

Here's our final training run:
"""


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        # Normalize each gradient tensor to unit L2 norm before combining
        dl_dw = [tf.math.l2_normalize(w) for w in dl_dw]
        d2l_dw2 = [tf.math.l2_normalize(w) for w in d2l_dw2]

        # Combine first-order and second-order gradients
        grads = [0.2 * w1 + 0.8 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
lr = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=25, decay_rate=0.1
)
model.compile(
    optimizer=keras.optimizers.SGD(lr),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
model.fit(x_train, y_train, epochs=50, batch_size=2048, validation_split=0.1)