"""
Title: Evaluating and exporting scikit-learn metrics in a Keras callback
Author: [lukewood](https://lukewood.xyz)
Date created: 10/07/2021
Last modified: 11/17/2023
Description: This example shows how to use Keras callbacks to evaluate and export non-TensorFlow-based metrics.
Accelerator: GPU
"""

"""
## Introduction

[Keras callbacks](https://keras.io/api/callbacks/) let you execute arbitrary code at
various stages of the Keras training process. While Keras offers first-class support
for metric evaluation, [Keras metrics](https://keras.io/api/metrics/) can only rely on
TensorFlow operations internally.

While TensorFlow implementations of many metrics are available online, some metrics
are implemented only in [NumPy](https://numpy.org/) or another Python-based numerical
computation library. By performing metric evaluation inside a Keras callback, we can
use any existing metric and export the result to TensorBoard.
"""

"""
## Jaccard score metric

This example uses a scikit-learn metric, `sklearn.metrics.jaccard_score()`, and
writes the result to TensorBoard using the `tf.summary` API.

With small modifications, this template works with any existing scikit-learn metric.
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
from sklearn.metrics import jaccard_score
import numpy as np


class JaccardScoreCallback(keras.callbacks.Callback):
    """Computes the Jaccard score and logs the results to TensorBoard."""

    def __init__(self, name, x_test, y_test, log_dir):
        super().__init__()
        self.x_test = x_test
        # Integer class labels of shape (num_samples,).
        self.y_test = y_test
        # Averages the per-class Jaccard scores into a single scalar.
        self.keras_metric = keras.metrics.Mean("jaccard_score")
        self.summary_writer = tf.summary.create_file_writer(os.path.join(log_dir, name))

    def on_epoch_end(self, epoch, logs=None):
        self.keras_metric.reset_state()
        predictions = self.model.predict(self.x_test, verbose=0)
        # scikit-learn expects (y_true, y_pred); average=None returns one score per class.
        jaccard_value = jaccard_score(
            self.y_test, np.argmax(predictions, axis=-1), average=None
        )
        self.keras_metric.update_state(jaccard_value)
        self._write_metric(
            self.keras_metric.name,
            self.keras_metric.result().numpy().astype(float),
            step=epoch,
        )

    def _write_metric(self, name, value, step):
        with self.summary_writer.as_default():
            tf.summary.scalar(name, value, step=step)
            self.summary_writer.flush()


"""
## Sample usage

Let's test our `JaccardScoreCallback` class with a Keras model.
"""
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


# Convert class vectors to one-hot encoded class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
callbacks = [
    # The callback expects integer labels, so undo the one-hot encoding.
    JaccardScoreCallback(model.name, x_test, np.argmax(y_test, axis=-1), "logs")
]
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
    callbacks=callbacks,
)

"""
If you now launch a TensorBoard instance using `tensorboard --logdir=logs`, you will
see the `jaccard_score` metric alongside any other exported metrics!
"""

"""
## Conclusion

Many ML practitioners and researchers rely on metrics that may not yet have a TensorFlow
implementation. Keras users can still leverage the wide variety of existing metric
implementations in other frameworks by using a Keras callback. These metrics can be
exported, viewed, and analyzed in TensorBoard like any other metric.
"""
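"""
To illustrate how the template might generalize to other scikit-learn metrics, here is
a minimal sketch of the metric-evaluation step in isolation, using only NumPy and
scikit-learn. It assumes softmax outputs of shape `(num_samples, num_classes)` and
integer labels. `evaluate_sklearn_metric` is a hypothetical helper, not part of Keras or
scikit-learn, and the toy arrays are made up. The sketch also checks that averaging the
per-class scores (what `keras.metrics.Mean` does in `JaccardScoreCallback`) matches
scikit-learn's `average="macro"`.

```python
import numpy as np
from sklearn.metrics import f1_score, jaccard_score


def evaluate_sklearn_metric(metric_fn, probabilities, y_true, **metric_kwargs):
    # Hypothetical helper: applies any scikit-learn classification metric
    # to softmax outputs. `probabilities` has shape (num_samples, num_classes);
    # `y_true` holds integer class labels of shape (num_samples,).
    # Collapse per-class probabilities into hard integer predictions.
    y_pred = np.argmax(probabilities, axis=-1)
    # scikit-learn metrics take (y_true, y_pred) in that order.
    return metric_fn(y_true, y_pred, **metric_kwargs)


# Made-up softmax outputs for 4 samples and 3 classes.
probabilities = np.array(
    [
        [0.8, 0.1, 0.1],
        [0.2, 0.7, 0.1],
        [0.1, 0.2, 0.7],
        [0.6, 0.3, 0.1],
    ]
)
y_true = np.array([0, 1, 2, 2])

# One Jaccard score per class, as in the callback above.
per_class = evaluate_sklearn_metric(
    jaccard_score, probabilities, y_true, average=None
)
# scikit-learn's macro average: the unweighted mean of the per-class scores.
macro = evaluate_sklearn_metric(
    jaccard_score, probabilities, y_true, average="macro"
)
# Swapping in a different metric only changes the function passed in.
macro_f1 = evaluate_sklearn_metric(f1_score, probabilities, y_true, average="macro")

print("per-class Jaccard:", per_class)
print("mean of per-class Jaccard:", per_class.mean())
print("macro Jaccard:", macro)
print("macro F1:", macro_f1)
```

In a callback, the metric function and its keyword arguments could be passed to
`__init__` and the `jaccard_score` call in `on_epoch_end` replaced by a call to such a
helper. Reusing a single callback class this way is a design choice; writing one small
callback per metric, as this example does, works just as well.
"""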