Path: examples/keras_recipes/endpoint_layer_pattern.py
"""1Title: Endpoint layer pattern2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2019/05/104Last modified: 2023/11/225Description: Demonstration of the "endpoint layer" pattern (layer that handles loss management).6Accelerator: GPU7"""89"""10## Setup11"""1213import os1415os.environ["KERAS_BACKEND"] = "tensorflow"1617import tensorflow as tf18import keras19import numpy as np2021"""22## Usage of endpoint layers in the Functional API2324An "endpoint layer" has access to the model's targets, and creates arbitrary losses25in `call()` using `self.add_loss()` and `Metric.update_state()`.26This enables you to define losses and27metrics that don't match the usual signature `fn(y_true, y_pred, sample_weight=None)`.2829Note that you could have separate metrics for training and eval with this pattern.30"""313233class LogisticEndpoint(keras.layers.Layer):34def __init__(self, name=None):35super().__init__(name=name)36self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)37self.accuracy_metric = keras.metrics.BinaryAccuracy(name="accuracy")3839def call(self, logits, targets=None, sample_weight=None):40if targets is not None:41# Compute the training-time loss value and add it42# to the layer using `self.add_loss()`.43loss = self.loss_fn(targets, logits, sample_weight)44self.add_loss(loss)4546# Log the accuracy as a metric (we could log arbitrary metrics,47# including different metrics for training and inference.)48self.accuracy_metric.update_state(targets, logits, sample_weight)4950# Return the inference-time prediction tensor (for `.predict()`).51return tf.nn.softmax(logits)525354inputs = keras.Input((764,), name="inputs")55logits = keras.layers.Dense(1)(inputs)56targets = keras.Input((1,), name="targets")57sample_weight = keras.Input((1,), name="sample_weight")58preds = LogisticEndpoint()(logits, targets, sample_weight)59model = keras.Model([inputs, targets, sample_weight], preds)6061data = {62"inputs": np.random.random((1000, 764)),63"targets": np.random.random((1000, 1)),64"sample_weight": np.random.random((1000, 1)),65}6667model.compile(keras.optimizers.Adam(1e-3))68model.fit(data, epochs=2)6970"""71## Exporting an inference-only model7273Simply don't include `targets` in the model. The weights stay the same.74"""7576inputs = keras.Input((764,), name="inputs")77logits = keras.layers.Dense(1)(inputs)78preds = LogisticEndpoint()(logits, targets=None, sample_weight=None)79inference_model = keras.Model(inputs, preds)8081inference_model.set_weights(model.get_weights())8283preds = inference_model.predict(np.random.random((1000, 764)))8485"""86## Usage of loss endpoint layers in subclassed models87"""888990class LogReg(keras.Model):91def __init__(self):92super().__init__()93self.dense = keras.layers.Dense(1)94self.logistic_endpoint = LogisticEndpoint()9596def call(self, inputs):97# Note that all inputs should be in the first argument98# since we want to be able to call `model.fit(inputs)`.99logits = self.dense(inputs["inputs"])100preds = self.logistic_endpoint(101logits=logits,102targets=inputs["targets"],103sample_weight=inputs["sample_weight"],104)105return preds106107108model = LogReg()109data = {110"inputs": np.random.random((1000, 764)),111"targets": np.random.random((1000, 1)),112"sample_weight": np.random.random((1000, 1)),113}114115model.compile(keras.optimizers.Adam(1e-3))116model.fit(data, epochs=2)117118119