Path: blob/master/examples/keras_recipes/reproducibility_recipes.py
"""1Title: Reproducibility in Keras Models2Author: [Frightera](https://github.com/Frightera)3Date created: 2023/05/054Last modified: 2023/05/055Description: Demonstration of random weight initialization and reproducibility in Keras models.6Accelerator: GPU7"""89"""10## Introduction1112This example demonstrates how to control randomness in Keras models. Sometimes13you may want to reproduce the exact same results across runs, for experimentation14purposes or to debug a problem.15"""1617"""18## Setup19"""20import json21import numpy as np22import tensorflow as tf23import keras24from keras import layers25from keras import initializers2627# Set the seed using keras.utils.set_random_seed. This will set:28# 1) `numpy` seed29# 2) backend random seed30# 3) `python` random seed31keras.utils.set_random_seed(812)3233# If using TensorFlow, this will make GPU ops as deterministic as possible,34# but it will affect the overall performance, so be mindful of that.35tf.config.experimental.enable_op_determinism()363738"""39## Weight initialization in Keras4041Most of the layers in Keras have `kernel_initializer` and `bias_initializer`42parameters. These parameters allow you to specify the strategy used for43initializing the weights of layer variables. The following built-in initializers44are available as part of `keras.initializers`:45"""4647initializers_list = [48initializers.RandomNormal,49initializers.RandomUniform,50initializers.TruncatedNormal,51initializers.VarianceScaling,52initializers.GlorotNormal,53initializers.GlorotUniform,54initializers.HeNormal,55initializers.HeUniform,56initializers.LecunNormal,57initializers.LecunUniform,58initializers.Orthogonal,59]6061"""62In a reproducible model, the weights of the model should be initialized with63same values in subsequent runs. First, we'll check how initializers behave when64they are called multiple times with same `seed` value.65"""6667for initializer in initializers_list:68print(f"Running {initializer}")6970for iteration in range(2):71# In order to get same results across multiple runs from an initializer,72# you can specify a seed value.73result = float(initializer(seed=42)(shape=(1, 1)))74print(f"\tIteration --> {iteration} // Result --> {result}")75print("\n")767778"""79Now, let's inspect how two different initializer objects behave when they are80have the same seed value.81"""8283# Setting the seed value for an initializer will cause two different objects84# to produce same results.85glorot_normal_1 = keras.initializers.GlorotNormal(seed=42)86glorot_normal_2 = keras.initializers.GlorotNormal(seed=42)8788input_dim, neurons = 3, 58990# Call two different objects with same shape91result_1 = glorot_normal_1(shape=(input_dim, neurons))92result_2 = glorot_normal_2(shape=(input_dim, neurons))9394# Check if the results are equal.95equal = np.allclose(result_1, result_2)96print(f"Are the results equal? {equal}")9798"""99If the seed value is not set (or different seed values are used), two different100objects will produce different results. Since the random seed is set at the beginning101of the notebook, the results will be same in the sequential runs. 
"""
If the seed value is not set (or different seed values are used), two different
objects will produce different results. However, since the random seed was set
at the beginning of the notebook via `keras.utils.set_random_seed`, the results
will still be the same across sequential runs of the notebook.
"""

glorot_normal_3 = keras.initializers.GlorotNormal()
glorot_normal_4 = keras.initializers.GlorotNormal()

# Let's call the first initializer.
result_3 = glorot_normal_3(shape=(input_dim, neurons))

# Call the second initializer.
result_4 = glorot_normal_4(shape=(input_dim, neurons))

equal = np.allclose(result_3, result_4)
print(f"Are the results equal? {equal}")

"""
`result_3` and `result_4` will be different, but when you run the notebook
again, `result_3` will have values identical to the ones in the previous run.
The same goes for `result_4`.
"""

"""
## Reproducibility in model training process

If you want to reproduce the results of a model training process, you need to
control the randomness sources during training. In order to show a realistic
example, this section utilizes `tf.data` with parallel map and shuffle
operations.

To start, let's create a simple function which returns the history object of
the Keras model.
"""


def train_model(train_data: tf.data.Dataset, test_data: tf.data.Dataset) -> dict:
    model = keras.Sequential(
        [
            layers.Conv2D(32, (3, 3), activation="relu"),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.2),
            layers.Conv2D(32, (3, 3), activation="relu"),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.2),
            layers.Conv2D(32, (3, 3), activation="relu"),
            layers.GlobalAveragePooling2D(),
            layers.Dense(64, activation="relu"),
            layers.Dropout(0.2),
            layers.Dense(10, activation="softmax"),
        ]
    )

    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
        jit_compile=False,
    )
    # jit_compile's default value is "auto", which causes problems with some
    # ops, therefore it's set to False here.

    # model.fit has a `shuffle` parameter which has a default value of `True`.
    # If you are using array-like objects, this will shuffle the data before
    # training. This argument is ignored when `x` is a generator or a
    # `tf.data.Dataset`.
    history = model.fit(train_data, epochs=2, validation_data=test_data)

    print(f"Model accuracy on test data: {model.evaluate(test_data)[1] * 100:.2f}%")

    return history.history


# Load the MNIST dataset
(train_images, train_labels), (
    test_images,
    test_labels,
) = keras.datasets.mnist.load_data()

# Construct tf.data.Dataset objects
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

"""
Remember we called `tf.config.experimental.enable_op_determinism()` at the
beginning of the notebook. This makes the `tf.data` operations deterministic.
However, making `tf.data` operations deterministic comes with a performance
cost. If you want to learn more about it, please check this
[official guide](https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism#determinism_and_tfdata).

A small summary of what's going on here: models have `kernel_initializer` and
`bias_initializer` parameters. Since we set random seeds using
`keras.utils.set_random_seed` at the beginning of the notebook, the initializers
will produce the same results across sequential runs. Additionally, TensorFlow
operations have now become deterministic. Note that you will frequently be
utilizing GPUs, whose thousands of hardware threads are a common cause of
non-deterministic behavior; that is what `enable_op_determinism()` guards
against.
"""
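"""
As a quick sanity check (this snippet is an addition to the recipe, not part of
the original; the dataset names are illustrative), you can verify that two
identically seeded `shuffle` transformations visit the elements in the same
order:
"""

# Two separate pipelines with the same explicit shuffle seed should yield
# the same permutation.
demo_ds_a = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=0)
demo_ds_b = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=0)

order_a = [int(x) for x in demo_ds_a]
order_b = [int(x) for x in demo_ds_b]
print(f"Identically seeded shuffles produce the same order? {order_a == order_b}")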
def prepare_dataset(image, label):
    # Cast and normalize the image
    image = tf.cast(image, tf.float32) / 255.0

    # Expand the channel dimension
    image = tf.expand_dims(image, axis=-1)

    # Resize the image
    image = tf.image.resize(image, (32, 32))

    return image, label


"""
`tf.data.Dataset` objects have a `shuffle` method which shuffles the data.
This method has a `buffer_size` parameter which controls the size of the
buffer. If you set this value to `len(train_images)`, the whole dataset will
be shuffled: when the buffer size equals the length of the dataset, the
elements are shuffled in a completely random order.

The main drawback of setting the buffer size to the length of the dataset is
that filling the buffer can take a while, depending on the size of the dataset.

Here is a small summary of what's going on here:
1) The `shuffle()` method creates a buffer of the specified size.
2) The elements of the dataset are randomly shuffled and placed into the buffer.
3) The elements of the buffer are then returned in a random order.

Since `tf.config.experimental.enable_op_determinism()` is enabled and we set
random seeds using `keras.utils.set_random_seed` at the beginning of the
notebook, the `shuffle()` method will produce the same results across
sequential runs.
"""
# Prepare the datasets, batch-map --> vectorized operations
train_data = (
    train_ds.shuffle(buffer_size=len(train_images))
    .batch(batch_size=64)
    .map(prepare_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

test_data = (
    test_ds.batch(batch_size=64)
    .map(prepare_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

"""
Train the model for the first time.
"""

history = train_model(train_data, test_data)

"""
Let's save our results into a JSON file, and restart the kernel. After
restarting the kernel, we should see the same results as in the previous run;
this includes the metrics and loss values on both the training and the test
data.
"""

# Save the history object into a json file
with open("history.json", "w") as fp:
    json.dump(history, fp)

"""
Do not re-run the cell above, so as not to overwrite the results. Execute the
model training cell again and compare the results.
"""

with open("history.json", "r") as fp:
    history_loaded = json.load(fp)


"""
Compare the results one by one. You will see that they are equal.
"""
for key in history.keys():
    for i in range(len(history[key])):
        if not np.allclose(history[key][i], history_loaded[key][i]):
            print(f"{key} not equal")
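"""
If you'd rather not restart the kernel, you can also attempt to reproduce a run
within the same session. The sketch below is an addition to the recipe (the
helper name `reproducible_run` is illustrative): it re-seeds everything and
rebuilds the input pipeline before training again, since both the weight
initialization and the shuffle order are derived from the seeds that were
active at creation time. Exact within-session reproduction assumes that
re-seeding resets the backend's internal op-seed counters.
"""


def reproducible_run(seed: int) -> dict:
    # Reset all random seeds so that initializer, dropout and tf.data draws
    # start over from the same state.
    keras.utils.set_random_seed(seed)
    # Rebuild the training pipeline: the shuffle seed is fixed when the
    # dataset is created, so a fresh run needs a fresh dataset.
    data = (
        train_ds.shuffle(buffer_size=len(train_images))
        .batch(batch_size=64)
        .map(prepare_dataset, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )
    return train_model(data, test_data)


# Two calls with the same seed should then produce matching histories:
# history_a = reproducible_run(812)
# history_b = reproducible_run(812)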
"""
## Conclusion

In this tutorial, you learned how to control the randomness sources in Keras
and TensorFlow. You also learned how to reproduce the results of a model
training process.

If you want to initialize the model with the same weights every time, you need
to set the `kernel_initializer` and `bias_initializer` parameters of the layers
and provide a `seed` value to the initializers.

There may still be some inconsistencies due to the accumulation of numerical
error, for example when using `recurrent_dropout` in RNN layers.

Reproducibility is subject to the environment: you'll get the same results only
if you run the notebook or the code on the same machine with the same
environment.
"""