Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/keras_recipes/endpoint_layer_pattern.py
3507 views
1
"""
Title: Endpoint layer pattern
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/05/10
Last modified: 2023/11/22
Description: Demonstration of the "endpoint layer" pattern (layer that handles loss management).
Accelerator: GPU
"""
"""
## Setup
"""
import os
15
16
os.environ["KERAS_BACKEND"] = "tensorflow"
17
18
import tensorflow as tf
19
import keras
20
import numpy as np
21
22
"""
## Usage of endpoint layers in the Functional API

An "endpoint layer" has access to the model's targets, and creates arbitrary losses
in `call()` using `self.add_loss()` and `Metric.update_state()`.
This enables you to define losses and
metrics that don't match the usual signature `fn(y_true, y_pred, sample_weight=None)`.

Note that you could have separate metrics for training and eval with this pattern.
"""
class LogisticEndpoint(keras.layers.Layer):
    """Endpoint layer: turns (logits, targets) into loss, metrics, and predictions.

    The layer adds the binary cross-entropy loss via `self.add_loss()` and
    tracks accuracy through a `BinaryAccuracy` metric, which lets the model
    use losses/metrics that don't fit the usual `fn(y_true, y_pred)` signature.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # `from_logits=True`: the loss applies the sigmoid internally, so
        # `call()` receives raw logits.
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_metric = keras.metrics.BinaryAccuracy(name="accuracy")

    def call(self, logits, targets=None, sample_weight=None):
        # `targets` is None at inference time, in which case the layer only
        # produces predictions and skips loss/metric bookkeeping.
        if targets is not None:
            # Compute the training-time loss value and add it
            # to the layer using `self.add_loss()`.
            loss = self.loss_fn(targets, logits, sample_weight)
            self.add_loss(loss)

            # Log the accuracy as a metric (we could log arbitrary metrics,
            # including different metrics for training and inference.)
            self.accuracy_metric.update_state(targets, logits, sample_weight)

        # Return the inference-time prediction tensor (for `.predict()`).
        # Bug fix: the logits have a single unit, and `tf.nn.softmax`
        # normalizes over the last axis, so on shape (batch, 1) it always
        # returned 1.0. The probability matching the `from_logits=True`
        # binary cross-entropy above is the sigmoid.
        return tf.math.sigmoid(logits)
# Build a training-time model whose inputs include the targets and sample
# weights, so the endpoint layer can compute the loss inside `call()`.
inputs = keras.Input((764,), name="inputs")
targets = keras.Input((1,), name="targets")
sample_weight = keras.Input((1,), name="sample_weight")

logits = keras.layers.Dense(1)(inputs)
preds = LogisticEndpoint()(logits, targets, sample_weight)
model = keras.Model([inputs, targets, sample_weight], preds)

# Random training data, keyed by the input names declared above.
shapes = {
    "inputs": (1000, 764),
    "targets": (1000, 1),
    "sample_weight": (1000, 1),
}
data = {name: np.random.random(shape) for name, shape in shapes.items()}

# No loss is passed to `compile()`: the endpoint layer already added it.
model.compile(keras.optimizers.Adam(1e-3))
model.fit(data, epochs=2)
"""
## Exporting an inference-only model

Simply don't include `targets` in the model. The weights stay the same.
"""
# Inference-only variant: same layer topology, but no `targets` input, so
# the endpoint layer skips loss/metric bookkeeping entirely.
inputs = keras.Input((764,), name="inputs")
logits = keras.layers.Dense(1)(inputs)
preds = LogisticEndpoint()(logits, targets=None, sample_weight=None)
inference_model = keras.Model(inputs, preds)

# The two models share the same layers in the same order, so the trained
# weights transfer directly.
inference_model.set_weights(model.get_weights())

test_batch = np.random.random((1000, 764))
preds = inference_model.predict(test_batch)
"""
## Usage of loss endpoint layers in subclassed models
"""
class LogReg(keras.Model):
    """Subclassed logistic-regression model built on the endpoint-layer pattern."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        self.logistic_endpoint = LogisticEndpoint()

    def call(self, inputs):
        # Every tensor arrives packed in the single `inputs` dict, which is
        # what lets the model be trained with a plain `model.fit(inputs)`.
        features = inputs["inputs"]
        logits = self.dense(features)
        return self.logistic_endpoint(
            logits=logits,
            targets=inputs["targets"],
            sample_weight=inputs["sample_weight"],
        )
# Train the subclassed model on random data; the dict keys match exactly
# what `LogReg.call()` reads from its `inputs` argument.
model = LogReg()
data = {
    key: np.random.random(shape)
    for key, shape in [
        ("inputs", (1000, 764)),
        ("targets", (1000, 1)),
        ("sample_weight", (1000, 1)),
    ]
}

model.compile(keras.optimizers.Adam(1e-3))
model.fit(data, epochs=2)