GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_recipes/sklearn_metric_callbacks.py
"""
Title: Evaluating and exporting scikit-learn metrics in a Keras callback
Author: [lukewood](https://lukewood.xyz)
Date created: 10/07/2021
Last modified: 11/17/2023
Description: This example shows how to use Keras callbacks to evaluate and export non-TensorFlow based metrics.
Accelerator: GPU
"""

"""
## Introduction

[Keras callbacks](https://keras.io/api/callbacks/) allow for the execution of arbitrary
code at various stages of the Keras training process. While Keras offers first-class
support for metric evaluation, [Keras metrics](https://keras.io/api/metrics/) may only
rely on TensorFlow code internally.

Many metrics have TensorFlow implementations available online, but others are
implemented only in [NumPy](https://numpy.org/) or another Python-based numerical computation library.
By performing metric evaluation inside a Keras callback, we can leverage any existing
metric and ultimately export the result to TensorBoard.
"""

"""
## Jaccard score metric

This example makes use of a scikit-learn metric, `sklearn.metrics.jaccard_score()`, and
writes the result to TensorBoard using the `tf.summary` API.

This template can be modified slightly to make it work with any existing scikit-learn metric.
"""

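"""
To make the metric concrete, here is a minimal pure-Python sketch of the per-class
Jaccard score: for each class, the number of samples where both the labels and the
predictions contain that class, divided by the number of samples where either does.
The helper `jaccard_per_class` and the toy labels are hypothetical and only illustrate
the arithmetic; this is not scikit-learn's implementation. Scoring a class that
appears in neither list as 0.0 is an assumption made for this sketch.
"""

```python
def jaccard_per_class(y_true, y_pred, num_classes):
    """Return one intersection-over-union score per class label."""
    scores = []
    for label in range(num_classes):
        # Sample indices where the ground truth / the prediction equals this label.
        true_idx = {i for i, y in enumerate(y_true) if y == label}
        pred_idx = {i for i, y in enumerate(y_pred) if y == label}
        union = true_idx | pred_idx
        # Assumption: a class absent from both lists scores 0.0.
        scores.append(len(true_idx & pred_idx) / len(union) if union else 0.0)
    return scores


# Toy data (hypothetical): six samples, three classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

per_class = jaccard_per_class(y_true, y_pred, num_classes=3)
mean_score = sum(per_class) / len(per_class)
print(per_class)
print(mean_score)
```

"""
Averaging the per-class scores yields a single scalar. The callback in this example
does the same kind of reduction by requesting per-class scores with `average=None`
and passing them through a `keras.metrics.Mean`.
"""
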
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
from sklearn.metrics import jaccard_score
import numpy as np


class JaccardScoreCallback(keras.callbacks.Callback):
    """Computes the Jaccard score and logs the results to TensorBoard."""

    def __init__(self, name, x_test, y_test, log_dir):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test
        self.keras_metric = keras.metrics.Mean("jaccard_score")
        self.epoch = 0
        self.summary_writer = tf.summary.create_file_writer(os.path.join(log_dir, name))

    def on_epoch_end(self, epoch, logs=None):
        self.epoch += 1
        self.keras_metric.reset_state()
        predictions = self.model.predict(self.x_test)
        # `average=None` returns one score per class; the `Mean` metric
        # reduces them to a single scalar.
        jaccard_value = jaccard_score(
            self.y_test, np.argmax(predictions, axis=-1), average=None
        )
        self.keras_metric.update_state(jaccard_value)
        self._write_metric(
            self.keras_metric.name, self.keras_metric.result().numpy().astype(float)
        )

    def _write_metric(self, name, value):
        with self.summary_writer.as_default():
            tf.summary.scalar(
                name,
                value,
                step=self.epoch,
            )
            self.summary_writer.flush()


"""
## Sample usage

Let's test our `JaccardScoreCallback` class with a Keras model.
"""
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

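"""
The labels are one-hot encoded for the loss, but the Jaccard score is computed on
integer class labels, which is why `np.argmax` is applied to both the predictions and
`y_test` elsewhere in this example. The following pure-Python sketch illustrates that
round trip. The helpers `to_one_hot` and `argmax_rows` and the sample values are
hypothetical stand-ins for illustration, not the Keras or NumPy implementations.
"""

```python
def to_one_hot(labels, num_classes):
    """Encode each integer label as a row with a single 1.0 at its index."""
    return [
        [1.0 if i == label else 0.0 for i in range(num_classes)] for label in labels
    ]


def argmax_rows(rows):
    """Return the index of the largest entry in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


# Integer labels survive a one-hot / argmax round trip unchanged.
labels = [3, 0, 2]
one_hot = to_one_hot(labels, num_classes=4)
recovered = argmax_rows(one_hot)

# The same argmax turns softmax-style probability rows into predicted labels.
probs = [[0.1, 0.7, 0.2, 0.0], [0.6, 0.1, 0.2, 0.1]]
predicted = argmax_rows(probs)

print(recovered)
print(predicted)
```
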
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 128
epochs = 15

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
callbacks = [
    JaccardScoreCallback(model.name, x_test, np.argmax(y_test, axis=-1), "logs")
]
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
    callbacks=callbacks,
)

"""
If you now launch a TensorBoard instance using `tensorboard --logdir=logs`, you will
see the `jaccard_score` metric alongside any other exported metrics!

![TensorBoard Jaccard Score](https://i.imgur.com/T4qzrdn.png)
"""

"""
## Conclusion

Many ML practitioners and researchers rely on metrics that may not yet have a TensorFlow
implementation. Keras users can still leverage the wide variety of existing metric
implementations in other frameworks by using a Keras callback. These metrics can be
exported, viewed, and analyzed in TensorBoard like any other metric.
"""