"""1Title: Keras debugging tips2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2020/05/164Last modified: 2023/11/165Description: Four simple tips to help you debug your Keras code.6Accelerator: GPU7"""89"""10## Introduction1112It's generally possible to do almost anything in Keras *without writing code* per se:13whether you're implementing a new type of GAN or the latest convnet architecture for14image segmentation, you can usually stick to calling built-in methods. Because all15built-in methods do extensive input validation checks, you will have little to no16debugging to do. A Functional API model made entirely of built-in layers will work on17first try -- if you can compile it, it will run.1819However, sometimes, you will need to dive deeper and write your own code. Here are some20common examples:2122- Creating a new `Layer` subclass.23- Creating a custom `Metric` subclass.24- Implementing a custom `train_step` on a `Model`.2526This document provides a few simple tips to help you navigate debugging in these27situations.2829"""3031"""32## Tip 1: test each part before you test the whole3334If you've created any object that has a chance of not working as expected, don't just35drop it in your end-to-end process and watch sparks fly. Rather, test your custom object36in isolation first. This may seem obvious -- but you'd be surprised how often people37don't start with this.3839- If you write a custom layer, don't call `fit()` on your entire model just yet. Call40your layer on some test data first.41- If you write a custom metric, start by printing its output for some reference inputs.4243Here's a simple example. Let's write a custom layer a bug in it:4445"""4647import os4849# The last example uses tf.GradientTape and thus requires TensorFlow.50# However, all tips here are applicable with all backends.51os.environ["KERAS_BACKEND"] = "tensorflow"5253import keras54from keras import layers55from keras import ops56import numpy as np57import tensorflow as tf585960class MyAntirectifier(layers.Layer):61def build(self, input_shape):62output_dim = input_shape[-1]63self.kernel = self.add_weight(64shape=(output_dim * 2, output_dim),65initializer="he_normal",66name="kernel",67trainable=True,68)6970def call(self, inputs):71# Take the positive part of the input72pos = ops.relu(inputs)73# Take the negative part of the input74neg = ops.relu(-inputs)75# Concatenate the positive and negative parts76concatenated = ops.concatenate([pos, neg], axis=0)77# Project the concatenation down to the same dimensionality as the input78return ops.matmul(concatenated, self.kernel)798081"""82Now, rather than using it in a end-to-end model directly, let's try to call the layer on83some test data:8485```python86x = tf.random.normal(shape=(2, 5))87y = MyAntirectifier()(x)88```8990We get the following error:9192```93...941 x = tf.random.normal(shape=(2, 5))95----> 2 y = MyAntirectifier()(x)96...9717 neg = tf.nn.relu(-inputs)9818 concatenated = tf.concat([pos, neg], axis=0)99---> 19 return tf.matmul(concatenated, self.kernel)100...101InvalidArgumentError: Matrix size-incompatible: In[0]: [4,5], In[1]: [10,5] [Op:MatMul]102```103104Looks like our input tensor in the `matmul` op may have an incorrect shape.105Let's add a print statement to check the actual shapes:106107"""108109110class MyAntirectifier(layers.Layer):111def build(self, input_shape):112output_dim = input_shape[-1]113self.kernel = self.add_weight(114shape=(output_dim * 2, 

Here's a more complete example. Let's write a custom layer with a bug in it:

"""

import os

# The last example uses tf.GradientTape and thus requires TensorFlow.
# However, all tips here are applicable with all backends.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers
from keras import ops
import numpy as np
import tensorflow as tf


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        # Take the positive part of the input
        pos = ops.relu(inputs)
        # Take the negative part of the input
        neg = ops.relu(-inputs)
        # Concatenate the positive and negative parts
        concatenated = ops.concatenate([pos, neg], axis=0)
        # Project the concatenation down to the same dimensionality as the input
        return ops.matmul(concatenated, self.kernel)


"""
Now, rather than using it in an end-to-end model directly, let's try to call the layer on
some test data:

```python
x = tf.random.normal(shape=(2, 5))
y = MyAntirectifier()(x)
```

We get the following error:

```
...
      1 x = tf.random.normal(shape=(2, 5))
----> 2 y = MyAntirectifier()(x)
...
     17 neg = tf.nn.relu(-inputs)
     18 concatenated = tf.concat([pos, neg], axis=0)
---> 19 return tf.matmul(concatenated, self.kernel)
...
InvalidArgumentError: Matrix size-incompatible: In[0]: [4,5], In[1]: [10,5] [Op:MatMul]
```

Looks like our input tensor in the `matmul` op may have an incorrect shape.
Let's add a print statement to check the actual shapes:

"""


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=0)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)


"""
We get the following:

```
pos.shape: (2, 5)
neg.shape: (2, 5)
concatenated.shape: (4, 5)
kernel.shape: (10, 5)
```

Turns out we had the wrong axis for the `concat` op! We should be concatenating `neg` and
`pos` along the feature axis 1, not the batch axis 0. Here's the correct version:
"""


class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=1)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)


"""
Now our code works fine:
"""

x = keras.random.normal(shape=(2, 5))
y = MyAntirectifier()(x)
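
"""
As an optional extra (not part of the original recipe), once the layer behaves as
expected you can freeze the manual check into an assertion, so that a future regression
in the layer fails loudly:
"""

# The antirectifier doubles the feature dimension internally and then projects back
# down, so the output shape should match the input shape.
assert tuple(y.shape) == (2, 5)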

"""
## Tip 2: use `model.summary()` and `plot_model()` to check layer output shapes

If you're working with complex network topologies, you're going to need a way
to visualize how your layers are connected and how they transform the data that passes
through them.

Here's an example. Consider this model with three inputs and two outputs (lifted from the
[Functional API guide](https://keras.io/guides/functional_api/#manipulate-complex-graph-topologies)):

"""

num_tags = 12  # Number of unique issue tags
num_words = 10000  # Size of vocabulary obtained when preprocessing text data
num_departments = 4  # Number of departments for predictions

title_input = keras.Input(
    shape=(None,), name="title"
)  # Variable-length sequence of ints
body_input = keras.Input(shape=(None,), name="body")  # Variable-length sequence of ints
tags_input = keras.Input(
    shape=(num_tags,), name="tags"
)  # Binary vectors of size `num_tags`

# Embed each word in the title into a 64-dimensional vector
title_features = layers.Embedding(num_words, 64)(title_input)
# Embed each word in the text into a 64-dimensional vector
body_features = layers.Embedding(num_words, 64)(body_input)

# Reduce sequence of embedded words in the title into a single 128-dimensional vector
title_features = layers.LSTM(128)(title_features)
# Reduce sequence of embedded words in the body into a single 32-dimensional vector
body_features = layers.LSTM(32)(body_features)

# Merge all available features into a single large vector via concatenation
x = layers.concatenate([title_features, body_features, tags_input])

# Stick a logistic regression for priority prediction on top of the features
priority_pred = layers.Dense(1, name="priority")(x)
# Stick a department classifier on top of the features
department_pred = layers.Dense(num_departments, name="department")(x)

# Instantiate an end-to-end model predicting both priority and department
model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority_pred, department_pred],
)

"""
Calling `summary()` can help you check the output shape of each layer:
"""

model.summary()

"""
You can also visualize the entire network topology alongside output shapes using
`plot_model`:
"""

keras.utils.plot_model(model, show_shapes=True)

"""
With this plot, any connectivity-level error becomes immediately obvious.
"""
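
"""
If you only care about one or two specific layers, you can also inspect their symbolic
output tensors directly (an optional extra, not part of the original example):
"""

# Every layer used in a Functional model exposes the output tensor it produced, so you
# can print its static shape without running any data through the model.
print("priority output shape:", model.get_layer("priority").output.shape)
print("department output shape:", model.get_layer("department").output.shape)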

"""
## Tip 3: to debug what happens during `fit()`, use `run_eagerly=True`

The `fit()` method is fast: it runs a well-optimized, fully-compiled computation graph.
That's great for performance, but it also means that the code you're executing isn't the
Python code you've written. This can be problematic when debugging. As you may recall,
Python is slow -- so we use it as a staging language, not as an execution language.

Thankfully, there's an easy way to run your code in "debug mode", fully eagerly:
pass `run_eagerly=True` to `compile()`. Your call to `fit()` will now get executed line
by line, without any optimization. It's slower, but it makes it possible to print the
value of intermediate tensors, or to use a Python debugger. Great for debugging.

Here's a basic example: let's write a really simple model with a custom `train_step()` method.
Our model just implements gradient descent, but instead of first-order gradients,
it uses a combination of first-order and second-order gradients. Pretty simple so far.

Can you spot what we're doing wrong?
"""


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


"""
Let's train a simple model on MNIST with this custom `train_step()`.

We pick, somewhat at random, a batch size of 1024 and a learning rate of 0.01. The
general idea is to use larger batches than usual, since our "improved" gradients should
lead us to quicker convergence.
"""


# Construct an instance of MyModel
def get_model():
    inputs = keras.Input(shape=(784,))
    intermediate = layers.Dense(256, activation="relu")(inputs)
    outputs = layers.Dense(10, activation="softmax")(intermediate)
    model = MyModel(inputs, outputs)
    return model


# Prepare data
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)) / 255

model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
)
model.fit(x_train, y_train, epochs=3, batch_size=1024, validation_split=0.1)

"""
Oh no, it doesn't converge! Something is not working as planned.

Time for some step-by-step printing of what's going on with our gradients.

We add various `print` statements in the `train_step` method, and we make sure to pass
`run_eagerly=True` to `compile()` to run our code step-by-step, eagerly.
"""


class MyModel(keras.Model):
    def train_step(self, data):
        print()
        print("----Start of step: %d" % (self.step_counter,))
        self.step_counter += 1

        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        print("Max of dl_dw[0]: %.4f" % tf.reduce_max(dl_dw[0]))
        print("Min of dl_dw[0]: %.4f" % tf.reduce_min(dl_dw[0]))
        print("Mean of dl_dw[0]: %.4f" % tf.reduce_mean(dl_dw[0]))
        print("-")
        print("Max of d2l_dw2[0]: %.4f" % tf.reduce_max(d2l_dw2[0]))
        print("Min of d2l_dw2[0]: %.4f" % tf.reduce_min(d2l_dw2[0]))
        print("Mean of d2l_dw2[0]: %.4f" % tf.reduce_mean(d2l_dw2[0]))

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
    run_eagerly=True,
)
model.step_counter = 0
# We pass epochs=1 and steps_per_epoch=10 to only run 10 steps of training.
model.fit(x_train, y_train, epochs=1, batch_size=1024, verbose=0, steps_per_epoch=10)

"""
What did we learn?

- The first-order and second-order gradients can have values that differ by orders of
magnitude.
- Sometimes, they may not even have the same sign.
- Their values can vary greatly at each step.

This leads us to an obvious idea: let's normalize the gradients before combining them.
"""
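
"""
`tf.math.l2_normalize` rescales a tensor to unit L2 norm, which puts both sets of
gradients on a comparable scale regardless of their raw magnitudes. A quick check on a
toy tensor (an optional illustration, not part of the original recipe):
"""

vector = tf.constant([3.0, 4.0])
# The direction is preserved and the norm becomes 1: this prints [0.6, 0.8].
print(tf.math.l2_normalize(vector))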


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        dl_dw = [tf.math.l2_normalize(w) for w in dl_dw]
        d2l_dw2 = [tf.math.l2_normalize(w) for w in d2l_dw2]

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-2),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
model.fit(x_train, y_train, epochs=5, batch_size=1024, validation_split=0.1)

"""
Now, training converges! It doesn't work well at all, but at least the model learns
something.

After spending a few minutes tuning parameters, we get to the following configuration
that works somewhat well (achieves 97% validation accuracy and seems reasonably robust to
overfitting):

- Use `0.2 * w1 + 0.8 * w2` for combining gradients.
- Use a learning rate schedule that decays over time (the final run below uses inverse
time decay; see the quick check that follows).
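
To get a feel for how fast that schedule decays, you can evaluate it at a few steps
before committing to a 50-epoch run (an optional check; the parameters are the ones used
in the final run below):

```python
schedule = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=25, decay_rate=0.1
)
for step in [0, 25, 250, 1000]:
    # The rate starts at 0.1 and has roughly halved by step 250.
    print(f"Step {step}: learning rate = {float(schedule(step)):.4f}")
```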

I'm not going to say that the idea works -- this isn't at all how you're supposed to do
second-order optimization (pointers: see the Newton & Gauss-Newton methods, quasi-Newton
methods, and BFGS). But hopefully this demonstration gave you an idea of how you can
debug your way out of uncomfortable training situations.

Remember: use `run_eagerly=True` for debugging what happens in `fit()`. And when your code
is finally working as expected, make sure to remove this flag in order to get the best
runtime performance!

Here's our final training run:
"""


class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        dl_dw = [tf.math.l2_normalize(w) for w in dl_dw]
        d2l_dw2 = [tf.math.l2_normalize(w) for w in d2l_dw2]

        # Combine first-order and second-order gradients
        grads = [0.2 * w1 + 0.8 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}


model = get_model()
lr = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.1, decay_steps=25, decay_rate=0.1
)
model.compile(
    optimizer=keras.optimizers.SGD(lr),
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)
model.fit(x_train, y_train, epochs=50, batch_size=2048, validation_split=0.1)