"""
Title: Writing Keras Models With TensorFlow NumPy
Author: [lukewood](https://lukewood.xyz)
Date created: 2021/08/28
Last modified: 2021/08/28
Description: Overview of how to use the TensorFlow NumPy API to write Keras models.
Accelerator: GPU
"""

"""
## Introduction

[NumPy](https://numpy.org/) is a hugely successful Python library for array computing
and linear algebra.

TensorFlow recently launched [tf_numpy](https://www.tensorflow.org/guide/tf_numpy), a
TensorFlow implementation of a large subset of the NumPy API.
Thanks to `tf_numpy`, you can write Keras layers or models in the NumPy style!

The TensorFlow NumPy API has full integration with the TensorFlow ecosystem.
Features such as automatic differentiation, TensorBoard, Keras model callbacks,
TPU distribution and model exporting are all supported.

Let's run through a few examples.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import keras
from keras import layers

"""
To test our models, we will use the Boston Housing price regression dataset.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data(
    path="boston_housing.npz", test_split=0.2, seed=113
)
input_dim = x_train.shape[1]


def evaluate_model(model: keras.Model):
    # Report the test-set error before training and again after 200 epochs.
    loss, percent_error = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error before training:", percent_error)
    model.fit(x_train, y_train, epochs=200, verbose=0)
    loss, percent_error = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error after training:", percent_error)


"""
## Subclassing keras.Model with TNP

The most flexible way to make use of the Keras API is to subclass the
[`keras.Model`](https://keras.io/api/models/model/) class. Subclassing the Model class
gives you the ability to fully customize what occurs in the training loop.
This makes subclassing Model a popular option for researchers.

In this example, we will implement a `Model` subclass that performs regression over the
Boston Housing dataset using the TNP API. Note that differentiation and gradient
descent are handled automatically when using the TNP API alongside Keras.

First, let's define a simple `TNPForwardFeedRegressionNetwork` class.
"""


class TNPForwardFeedRegressionNetwork(keras.Model):
    def __init__(self, blocks=None, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(blocks, list):
            raise ValueError(f"blocks must be a list, got blocks={blocks}")
        self.blocks = blocks
        self.block_weights = None
        self.biases = None

    def build(self, input_shape):
        current_shape = input_shape[1]
        self.block_weights = []
        self.biases = []
        for i, block in enumerate(self.blocks):
            self.block_weights.append(
                self.add_weight(
                    shape=(current_shape, block),
                    trainable=True,
                    name=f"block-{i}",
                    initializer="glorot_normal",
                )
            )
            self.biases.append(
                self.add_weight(
                    shape=(block,),
                    trainable=True,
                    name=f"bias-{i}",
                    initializer="zeros",
                )
            )
            current_shape = block

        self.linear_layer = self.add_weight(
            shape=(current_shape, 1),
            name="linear_projector",
            trainable=True,
            initializer="glorot_normal",
        )

    def call(self, inputs):
        activations = inputs
        for w, b in zip(self.block_weights, self.biases):
            activations = tnp.matmul(activations, w) + b
            # ReLU activation function
            activations = tnp.maximum(activations, 0.0)

        return tnp.matmul(activations, self.linear_layer)


"""
Just like with any other Keras model, we can use any supported optimizer, loss,
metrics or callbacks that we want.

Let's see how the model performs!
"""

model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)

"""
Great!
Our model seems to be effectively learning to solve the problem at hand.

We can also write our own custom loss function using TNP.
"""


def tnp_mse(y_true, y_pred):
    return tnp.mean(tnp.square(y_true - y_pred), axis=0)


keras.backend.clear_session()
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    optimizer="adam",
    loss=tnp_mse,
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)

"""
## Implementing a Keras Layer Based Model with TNP

If desired, TNP can also be used in layer-oriented Keras code. Let's
implement the same model, but using a layered approach!
"""


def tnp_relu(x):
    return tnp.maximum(x, 0)


class TNPDense(keras.layers.Layer):
    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        self.w = self.add_weight(
            name="weights",
            shape=(input_shape[1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        outputs = tnp.matmul(inputs, self.w) + self.bias
        if self.activation:
            return self.activation(outputs)
        return outputs


def create_layered_tnp_model():
    return keras.Sequential(
        [
            TNPDense(3, activation=tnp_relu),
            TNPDense(3, activation=tnp_relu),
            TNPDense(1),
        ]
    )


model = create_layered_tnp_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, input_dim))
model.summary()

evaluate_model(model)

"""
You can also seamlessly switch between TNP layers and native Keras layers!
"""


def create_mixed_model():
    return keras.Sequential(
        [
            TNPDense(3, activation=tnp_relu),
            # The model will have no issue using a normal Dense layer
            layers.Dense(3, activation="relu"),
            # ... or switching back to tnp layers!
            TNPDense(1),
        ]
    )


model = create_mixed_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, input_dim))
model.summary()

evaluate_model(model)

"""
The Keras API offers a wide variety of layers. The ability to use them alongside NumPy
code can be a huge time saver in projects.
"""

"""
## Distribution Strategy

TensorFlow NumPy and Keras integrate with
[TensorFlow Distribution Strategies](https://www.tensorflow.org/guide/distributed_training).
This makes it simple to perform distributed training across multiple GPUs,
or even an entire TPU Pod.
"""

gpus = tf.config.list_logical_devices("GPU")
if gpus:
    strategy = tf.distribute.MirroredStrategy(gpus)
else:
    # Fall back to the default (no-op) strategy, which runs on a single device.
    strategy = tf.distribute.get_strategy()
print("Running with strategy:", str(strategy.__class__.__name__))

with strategy.scope():
    model = create_layered_tnp_model()
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.build((None, input_dim))
    model.summary()
    evaluate_model(model)

"""
## TensorBoard Integration

One of the many benefits of using the Keras API is the ability to monitor training
through TensorBoard.
Using the TensorFlow NumPy API alongside Keras allows you to easily
leverage TensorBoard.
"""

keras.backend.clear_session()

"""
To load the TensorBoard extension in a Jupyter notebook, you can run the following magic:
```
%load_ext tensorboard
```
"""

models = [
    (
        TNPForwardFeedRegressionNetwork(blocks=[3, 3]),
        "TNPForwardFeedRegressionNetwork",
    ),
    (create_layered_tnp_model(), "layered_tnp_model"),
    (create_mixed_model(), "mixed_model"),
]
for model, model_name in models:
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.fit(
        x_train,
        y_train,
        epochs=200,
        verbose=0,
        callbacks=[keras.callbacks.TensorBoard(log_dir=f"logs/{model_name}")],
    )

"""
To launch TensorBoard from a Jupyter notebook, you can use the `%tensorboard` magic:

```
%tensorboard --logdir logs
```

TensorBoard lets you monitor metrics and examine the training curve.

TensorBoard also allows you to explore the computation graph used in your models.

The ability to introspect into your models can be valuable during debugging.
"""

"""
## Conclusion

Porting existing NumPy code to Keras models using the TensorFlow NumPy API is easy!
By integrating with Keras you gain the ability to use existing Keras callbacks, metrics
and optimizers, easily distribute your training and use TensorBoard.

Migrating a more complex model, such as a ResNet, to the TensorFlow NumPy API would be a
great follow-up learning exercise.

Several open source NumPy ResNet implementations are available online.
"""
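"""
This small sketch (not part of the original example) illustrates how direct such a port can be. It writes one dense block with a ReLU in plain NumPy, then writes it again with `tnp`, and checks that the two produce matching outputs. The helper names `numpy_forward` and `tnp_forward`, the random shapes and the comparison tolerance are illustrative assumptions, not part of any library API.

```python
import numpy as np
import tensorflow.experimental.numpy as tnp


def numpy_forward(x, w, b):
    # Hypothetical helper: one dense block followed by a ReLU, in plain NumPy.
    return np.maximum(np.matmul(x, w) + b, 0.0)


def tnp_forward(x, w, b):
    # Hypothetical helper: the same code with `np` swapped for `tnp`.
    # It accepts NumPy arrays and returns a TensorFlow tensor.
    return tnp.maximum(tnp.matmul(x, w) + b, 0.0)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 13)).astype("float32")  # 13 features, like Boston Housing
w = rng.normal(size=(13, 3)).astype("float32")
b = np.zeros(3, dtype="float32")

np_out = numpy_forward(x, w, b)
tnp_out = tnp_forward(x, w, b)

# Loose tolerance, because TensorFlow may use reduced-precision matmul
# (for example TF32) on some GPUs.
print(np.allclose(np_out, tnp_out.numpy(), rtol=1e-2, atol=1e-2))
```

The two functions differ only in the module prefix. Because the `tnp` version returns a `tf.Tensor`, the same expression can be used directly inside a Keras layer's `call` method, as `TNPDense` does above.
"""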