Path: blob/master/examples/vision/forwardforward.py
3507 views
"""1Title: Using the Forward-Forward Algorithm for Image Classification2Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)3Date created: 2023/01/084Last modified: 2024/09/175Description: Training a Dense-layer model using the Forward-Forward algorithm.6Accelerator: GPU7"""89"""10## Introduction1112The following example explores how to use the Forward-Forward algorithm to perform13training instead of the traditionally-used method of backpropagation, as proposed by14Hinton in15[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)16(2022).1718The concept was inspired by the understanding behind19[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation20involves calculating the difference between actual and predicted output via a cost21function to adjust network weights. On the other hand, the FF Algorithm suggests the22analogy of neurons which get "excited" based on looking at a certain recognized23combination of an image and its correct corresponding label.2425This method takes certain inspiration from the biological learning process that occurs in26the cortex. A significant advantage that this method brings is the fact that27backpropagation through the network does not need to be performed anymore, and that28weight updates are local to the layer itself.2930As this is yet still an experimental method, it does not yield state-of-the-art results.31But with proper tuning, it is supposed to come close to the same.32Through this example, we will examine a process that allows us to implement the33Forward-Forward algorithm within the layers themselves, instead of the traditional method34of relying on the global loss functions and optimizers.3536The tutorial is structured as follows:3738- Perform necessary imports39- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)40- Visualize Random samples from the MNIST dataset41- Define a `FFDense` Layer to override `call` and implement a custom `forwardforward`42method which performs weight updates.43- Define a `FFNetwork` Layer to override `train_step`, `predict` and implement 2 custom44functions for per-sample prediction and overlaying labels45- Convert MNIST from `NumPy` arrays to `tf.data.Dataset`46- Fit the network47- Visualize results48- Perform inference on test samples4950As this example requires the customization of certain core functions with51`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for52a primer on how to do so:5354- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)55- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models)56"""5758"""59## Setup imports60"""61import os6263os.environ["KERAS_BACKEND"] = "tensorflow"6465import tensorflow as tf66import keras67from keras import ops68import numpy as np69import matplotlib.pyplot as plt70from sklearn.metrics import accuracy_score71import random72from tensorflow.compiler.tf2xla.python import xla7374"""75## Load the dataset and visualize the data7677We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset78in the form of `NumPy` arrays. We then arrange it in the form of the train and test79splits.8081Following loading the dataset, we select 4 random samples from within the training set82and visualize them using `matplotlib.pyplot`.83"""8485(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()8687print("4 Random Training samples and labels")88idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)8990img1 = (x_train[idx1], y_train[idx1])91img2 = (x_train[idx2], y_train[idx2])92img3 = (x_train[idx3], y_train[idx3])93img4 = (x_train[idx4], y_train[idx4])9495imgs = [img1, img2, img3, img4]9697plt.figure(figsize=(10, 10))9899for idx, item in enumerate(imgs):100image, label = item[0], item[1]101plt.subplot(2, 2, idx + 1)102plt.imshow(image, cmap="gray")103plt.title(f"Label : {label}")104plt.show()105106"""107## Define `FFDense` custom layer108109In this custom layer, we have a base `keras.layers.Dense` object which acts as the110base `Dense` layer within. Since weight updates will happen within the layer itself, we111add an `keras.optimizers.Optimizer` object that is accepted from the user. Here, we112use `Adam` as our optimizer with a rather higher learning rate of `0.03`.113114Following the algorithm's specifics, we must set a `threshold` parameter that will be115used to make the positive-negative decision in each prediction. This is set to a default116of 2.0.117As the epochs are localized to the layer itself, we also set a `num_epochs` parameter118(defaults to 50).119120We override the `call` method in order to perform a normalization over the complete121input space followed by running it through the base `Dense` layer as would happen in a122normal `Dense` layer call.123124We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each125representing the positive and negative samples respectively. We write a custom training126loop here with the use of `tf.GradientTape()`, within which we calculate a loss per127sample by taking the distance of the prediction from the threshold to understand the128error and taking its mean to get a `mean_loss` metric.129130With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable131base `Dense` layer and apply them using the layer's local optimizer.132133Finally, we return the `call` result as the `Dense` results of the positive and negative134samples while also returning the last `mean_loss` metric and all the loss values over a135certain all-epoch run.136"""137138139class FFDense(keras.layers.Layer):140"""141A custom ForwardForward-enabled Dense layer. It has an implementation of the142Forward-Forward network internally for use.143This layer must be used in conjunction with the `FFNetwork` model.144"""145146def __init__(147self,148units,149init_optimizer,150loss_metric,151num_epochs=50,152use_bias=True,153kernel_initializer="glorot_uniform",154bias_initializer="zeros",155kernel_regularizer=None,156bias_regularizer=None,157**kwargs,158):159super().__init__(**kwargs)160self.dense = keras.layers.Dense(161units=units,162use_bias=use_bias,163kernel_initializer=kernel_initializer,164bias_initializer=bias_initializer,165kernel_regularizer=kernel_regularizer,166bias_regularizer=bias_regularizer,167)168self.relu = keras.layers.ReLU()169self.optimizer = init_optimizer()170self.loss_metric = loss_metric171self.threshold = 1.5172self.num_epochs = num_epochs173174# We perform a normalization step before we run the input through the Dense175# layer.176177def call(self, x):178x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)179x_norm = x_norm + 1e-4180x_dir = x / x_norm181res = self.dense(x_dir)182return self.relu(res)183184# The Forward-Forward algorithm is below. We first perform the Dense-layer185# operation and then get a Mean Square value for all positive and negative186# samples respectively.187# The custom loss function finds the distance between the Mean-squared188# result and the threshold value we set (a hyperparameter) that will define189# whether the prediction is positive or negative in nature. Once the loss is190# calculated, we get a mean across the entire batch combined and perform a191# gradient calculation and optimization step. This does not technically192# qualify as backpropagation since there is no gradient being193# sent to any previous layer and is completely local in nature.194195def forward_forward(self, x_pos, x_neg):196for i in range(self.num_epochs):197with tf.GradientTape() as tape:198g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)199g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)200201loss = ops.log(2021203+ ops.exp(204ops.concatenate(205[-g_pos + self.threshold, g_neg - self.threshold], 0206)207)208)209mean_loss = ops.cast(ops.mean(loss), dtype="float32")210self.loss_metric.update_state([mean_loss])211gradients = tape.gradient(mean_loss, self.dense.trainable_weights)212self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))213return (214ops.stop_gradient(self.call(x_pos)),215ops.stop_gradient(self.call(x_neg)),216self.loss_metric.result(),217)218219220"""221## Define the `FFNetwork` Custom Model222223With our custom layer defined, we also need to override the `train_step` method and224define a custom `keras.models.Model` that works with our `FFDense` layer.225226For this algorithm, we must 'embed' the labels onto the original image. To do so, we227exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We228use that as a label space in order to visually one-hot-encode the labels within the image229itself. This action is performed by the `overlay_y_on_x` function.230231We break down the prediction function with a per-sample prediction function which is then232called over the entire test set by the overriden `predict()` function. The prediction is233performed here with the help of measuring the `excitation` of the neurons per layer for234each image. This is then summed over all layers to calculate a network-wide 'goodness235score'. The label with the highest 'goodness score' is then chosen as the sample236prediction.237238The `train_step` function is overriden to act as the main controlling loop for running239training on each layer as per the number of epochs per layer.240"""241242243class FFNetwork(keras.Model):244"""245A `keras.Model` that supports a `FFDense` network creation. This model246can work for any kind of classification task. It has an internal247implementation with some details specific to the MNIST dataset which can be248changed as per the use-case.249"""250251# Since each layer runs gradient-calculation and optimization locally, each252# layer has its own optimizer that we pass. As a standard choice, we pass253# the `Adam` optimizer with a default learning rate of 0.03 as that was254# found to be the best rate after experimentation.255# Loss is tracked using `loss_var` and `loss_count` variables.256257def __init__(258self,259dims,260init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),261**kwargs,262):263super().__init__(**kwargs)264self.init_layer_optimizer = init_layer_optimizer265self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")266self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")267self.layer_list = [keras.Input(shape=(dims[0],))]268self.metrics_built = False269for d in range(len(dims) - 1):270self.layer_list += [271FFDense(272dims[d + 1],273init_optimizer=self.init_layer_optimizer,274loss_metric=keras.metrics.Mean(),275)276]277278# This function makes a dynamic change to the image wherein the labels are279# put on top of the original image (for this example, as MNIST has 10280# unique labels, we take the top-left corner's first 10 pixels). This281# function returns the original data tensor with the first 10 pixels being282# a pixel-based one-hot representation of the labels.283284@tf.function(reduce_retracing=True)285def overlay_y_on_x(self, data):286X_sample, y_sample = data287max_sample = ops.amax(X_sample, axis=0, keepdims=True)288max_sample = ops.cast(max_sample, dtype="float64")289X_zeros = ops.zeros([10], dtype="float64")290X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])291X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])292return X_sample, y_sample293294# A custom `predict_one_sample` performs predictions by passing the images295# through the network, measures the results produced by each layer (i.e.296# how high/low the output values are with respect to the set threshold for297# each label) and then simply finding the label with the highest values.298# In such a case, the images are tested for their 'goodness' with all299# labels.300301@tf.function(reduce_retracing=True)302def predict_one_sample(self, x):303goodness_per_label = []304x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])305for label in range(10):306h, label = self.overlay_y_on_x(data=(x, label))307h = ops.reshape(h, [-1, ops.shape(h)[0]])308goodness = []309for layer_idx in range(1, len(self.layer_list)):310layer = self.layer_list[layer_idx]311h = layer(h)312goodness += [ops.mean(ops.power(h, 2), 1)]313goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]314goodness_per_label = tf.concat(goodness_per_label, 1)315return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")316317def predict(self, data):318x = data319preds = list()320preds = ops.vectorized_map(self.predict_one_sample, x)321return np.asarray(preds, dtype=int)322323# This custom `train_step` function overrides the internal `train_step`324# implementation. We take all the input image tensors, flatten them and325# subsequently produce positive and negative samples on the images.326# A positive sample is an image that has the right label encoded on it with327# the `overlay_y_on_x` function. A negative sample is an image that has an328# erroneous label present on it.329# With the samples ready, we pass them through each `FFLayer` and perform330# the Forward-Forward computation on it. The returned loss is the final331# loss value over all the layers.332333@tf.function(jit_compile=False)334def train_step(self, data):335x, y = data336337if not self.metrics_built:338# build metrics to ensure they can be queried without erroring out.339# We can't update the metrics' state, as we would usually do, since340# we do not perform predictions within the train step341for metric in self.metrics:342if hasattr(metric, "build"):343metric.build(y, y)344self.metrics_built = True345346# Flatten op347x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])348349x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))350351random_y = tf.random.shuffle(y)352x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))353354h_pos, h_neg = x_pos, x_neg355356for idx, layer in enumerate(self.layers):357if isinstance(layer, FFDense):358print(f"Training layer {idx+1} now : ")359h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)360self.loss_var.assign_add(loss)361self.loss_count.assign_add(1.0)362else:363print(f"Passing layer {idx+1} now : ")364x = layer(x)365mean_res = ops.divide(self.loss_var, self.loss_count)366return {"FinalLoss": mean_res}367368369"""370## Convert MNIST `NumPy` arrays to `tf.data.Dataset`371372We now perform some preliminary processing on the `NumPy` arrays and then convert them373into the `tf.data.Dataset` format which allows for optimized loading.374"""375376x_train = x_train.astype(float) / 255377x_test = x_test.astype(float) / 255378y_train = y_train.astype(int)379y_test = y_test.astype(int)380381train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))382test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))383384train_dataset = train_dataset.batch(60000)385test_dataset = test_dataset.batch(10000)386387"""388## Fit the network and visualize results389390Having performed all previous set-up, we are now going to run `model.fit()` and run 250391model epochs, which will perform 50*250 epochs on each layer. We get to see the plotted loss392curve as each layer is trained.393"""394395model = FFNetwork(dims=[784, 500, 500])396397model.compile(398optimizer=keras.optimizers.Adam(learning_rate=0.03),399loss="mse",400jit_compile=False,401metrics=[],402)403404epochs = 250405history = model.fit(train_dataset, epochs=epochs)406407"""408## Perform inference and testing409410Having trained the model to a large extent, we now see how it performs on the411test set. We calculate the Accuracy Score to understand the results closely.412"""413414preds = model.predict(ops.convert_to_tensor(x_test))415416preds = preds.reshape((preds.shape[0], preds.shape[1]))417418results = accuracy_score(preds, y_test)419420print(f"Test Accuracy score : {results*100}%")421422plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"])423plt.title("Loss over training")424plt.show()425426"""427## Conclusion428429This example has hereby demonstrated how the Forward-Forward algorithm works using430the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton431in their paper are currently still limited to smaller models and datasets like MNIST and432Fashion-MNIST, subsequent results on larger models like LLMs are expected in future433papers.434435Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a4362000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning437that backpropagation takes only 20 epochs to achieve similar performance). Another run of438doubling the learning rate and training for 40 epochs yields a slightly worse error rate439of 1.46%440441The current example does not yield state-of-the-art results. But with proper tuning of442the Learning Rate, model architecture (number of units in `Dense` layers, kernel443activations, initializations, regularization etc.), the results can be improved444to match the claims of the paper.445"""446447448