"""
2
Title: Using the Forward-Forward Algorithm for Image Classification
3
Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
4
Date created: 2023/01/08
5
Last modified: 2024/09/17
6
Description: Training a Dense-layer model using the Forward-Forward algorithm.
7
Accelerator: GPU
8
"""

"""
## Introduction

The following example explores how to use the Forward-Forward algorithm to perform
training instead of the traditionally-used method of backpropagation, as proposed by
Hinton in
[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)
(2022).

The concept was inspired by
[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation
computes the difference between actual and predicted output via a cost function and uses
it to adjust the network weights. The FF algorithm, by contrast, draws an analogy with
neurons that get "excited" when they see a recognized combination of an image and its
correct corresponding label.

This method takes inspiration from the biological learning process that occurs in
the cortex. A significant advantage of this method is that backpropagation through the
network no longer needs to be performed, and weight updates are local to the layer
itself.
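
As a minimal sketch of the idea (plain `NumPy`, independent of the Keras layers defined
later in this example), each layer's "goodness" is the mean of its squared activations,
and the positive/negative decision is made against a fixed threshold:

```python
import numpy as np

def goodness(h):
    # Mean squared activation per sample.
    return np.mean(np.square(h), axis=-1)

threshold = 1.5
h = np.random.rand(8, 500)  # hypothetical activations for a batch of 8 samples
is_positive = goodness(h) > threshold  # True where the layer is "excited"
```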

As this is still an experimental method, it does not yield state-of-the-art results.
But with proper tuning, it is supposed to come close to them.
Through this example, we will examine a process that allows us to implement the
Forward-Forward algorithm within the layers themselves, instead of the traditional method
of relying on global loss functions and optimizers.

The tutorial is structured as follows:

- Perform necessary imports
- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)
- Visualize random samples from the MNIST dataset
- Define a `FFDense` layer that overrides `call` and implements a custom `forward_forward`
method which performs weight updates
- Define a `FFNetwork` model that overrides `train_step`, `predict` and implements 2 custom
functions for per-sample prediction and overlaying labels
- Convert MNIST from `NumPy` arrays to `tf.data.Dataset`
- Fit the network
- Visualize results
- Perform inference on test samples

As this example requires the customization of certain core functions with
`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for
a primer on how to do so:

- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)
- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
"""

"""
## Setup imports
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import random
from tensorflow.compiler.tf2xla.python import xla

"""
## Load the dataset and visualize the data

We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset
in the form of `NumPy` arrays, already arranged into train and test splits.

After loading the dataset, we select 4 random samples from the training set
and visualize them using `matplotlib.pyplot`.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print("4 Random Training samples and labels")
idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)

img1 = (x_train[idx1], y_train[idx1])
img2 = (x_train[idx2], y_train[idx2])
img3 = (x_train[idx3], y_train[idx3])
img4 = (x_train[idx4], y_train[idx4])

imgs = [img1, img2, img3, img4]

plt.figure(figsize=(10, 10))

for idx, item in enumerate(imgs):
    image, label = item[0], item[1]
    plt.subplot(2, 2, idx + 1)
    plt.imshow(image, cmap="gray")
    plt.title(f"Label : {label}")
plt.show()

"""
## Define `FFDense` custom layer

In this custom layer, we have a base `keras.layers.Dense` object which acts as the
core `Dense` layer within. Since weight updates will happen within the layer itself, we
add a `keras.optimizers.Optimizer` object that is accepted from the user. Here, we
use `Adam` as our optimizer with a rather high learning rate of `0.03`.

Following the algorithm's specifics, we must set a `threshold` parameter that will be
used to make the positive-negative decision in each prediction. This is set to a default
of 1.5.
As the epochs are localized to the layer itself, we also set a `num_epochs` parameter
(defaults to 50).

We override the `call` method in order to perform a normalization over the complete
input space followed by running it through the base `Dense` layer, as would happen in a
normal `Dense` layer call.

We implement the Forward-Forward algorithm, which accepts 2 kinds of input tensors,
representing the positive and negative samples respectively. We write a custom training
loop here with the use of `tf.GradientTape()`, within which we calculate a loss per
sample by measuring how far each sample's "goodness" (mean squared activation) lies from
the threshold, and take the mean over the batch to get a `mean_loss` metric.

With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable
base `Dense` layer and apply them using the layer's local optimizer.

Finally, we return the layer outputs for the positive and negative samples, along with
the accumulated `mean_loss` metric over the layer's local epochs.
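
Concretely, the per-layer objective implemented in `forward_forward` below is a
softplus-style loss over the concatenated positive and negative batches. A minimal sketch
in the same ops used by the layer, with hypothetical toy tensors standing in for the
layer outputs:

```python
from keras import ops

threshold = 1.5
# Hypothetical layer outputs for 8 positive and 8 negative samples, 500 units each.
h_pos = ops.ones((8, 500)) * 0.1
h_neg = ops.ones((8, 500)) * 0.05

g_pos = ops.mean(ops.power(h_pos, 2), 1)  # per-sample goodness of positive samples
g_neg = ops.mean(ops.power(h_neg, 2), 1)  # per-sample goodness of negative samples

# Push positive goodness above the threshold and negative goodness below it.
loss = ops.log(1 + ops.exp(ops.concatenate([threshold - g_pos, g_neg - threshold], 0)))
mean_loss = ops.mean(loss)
```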
"""


class FFDense(keras.layers.Layer):
    """
    A custom ForwardForward-enabled Dense layer. It has an implementation of the
    Forward-Forward network internally for use.
    This layer must be used in conjunction with the `FFNetwork` model.
    """

    def __init__(
        self,
        units,
        init_optimizer,
        loss_metric,
        num_epochs=50,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=units,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.relu = keras.layers.ReLU()
        self.optimizer = init_optimizer()
        self.loss_metric = loss_metric
        self.threshold = 1.5
        self.num_epochs = num_epochs

    # We perform a normalization step before we run the input through the Dense
    # layer.

    def call(self, x):
        x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
        x_norm = x_norm + 1e-4
        x_dir = x / x_norm
        res = self.dense(x_dir)
        return self.relu(res)

    # The Forward-Forward algorithm is below. We first perform the Dense-layer
    # operation and then get a Mean Square value for all positive and negative
    # samples respectively.
    # The custom loss function finds the distance between the Mean-squared
    # result and the threshold value we set (a hyperparameter) that will define
    # whether the prediction is positive or negative in nature. Once the loss is
    # calculated, we get a mean across the entire batch combined and perform a
    # gradient calculation and optimization step. This does not technically
    # qualify as backpropagation since there is no gradient being sent to any
    # previous layer; all updates are completely local in nature.

    def forward_forward(self, x_pos, x_neg):
        for i in range(self.num_epochs):
            with tf.GradientTape() as tape:
                g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
                g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)

                loss = ops.log(
                    1
                    + ops.exp(
                        ops.concatenate(
                            [-g_pos + self.threshold, g_neg - self.threshold], 0
                        )
                    )
                )
                mean_loss = ops.cast(ops.mean(loss), dtype="float32")
                self.loss_metric.update_state([mean_loss])
            gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
            self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
        return (
            ops.stop_gradient(self.call(x_pos)),
            ops.stop_gradient(self.call(x_neg)),
            self.loss_metric.result(),
        )
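

"""
Purely to illustrate the interface defined above (this sketch is not part of the
tutorial's training code, and the tensors shown are hypothetical placeholders), a single
`FFDense` layer is driven by handing it a positive and a negative batch of flattened
images:

```python
layer = FFDense(
    500,
    init_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
    loss_metric=keras.metrics.Mean(),
    num_epochs=50,
)
x_pos = ops.ones((32, 784))  # hypothetical batch with the correct labels overlaid
x_neg = ops.zeros((32, 784))  # hypothetical batch with wrong labels overlaid
h_pos, h_neg, layer_loss = layer.forward_forward(x_pos, x_neg)
```

The `FFNetwork` model below does exactly this for every `FFDense` layer in sequence,
feeding each layer's outputs to the next.
"""
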
"""
222
## Define the `FFNetwork` Custom Model
223
224
With our custom layer defined, we also need to override the `train_step` method and
225
define a custom `keras.models.Model` that works with our `FFDense` layer.
226
227
For this algorithm, we must 'embed' the labels onto the original image. To do so, we
228
exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We
229
use that as a label space in order to visually one-hot-encode the labels within the image
230
itself. This action is performed by the `overlay_y_on_x` function.
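
As a minimal sketch of what this overlay does (plain `NumPy`, with a hypothetical
`overlay_label` helper that is not used by the model itself):

```python
import numpy as np

def overlay_label(flat_image, label, num_classes=10):
    # Zero out the first `num_classes` pixels, then mark the label's position
    # with the image's maximum pixel value.
    out = flat_image.copy()
    out[:num_classes] = 0.0
    out[label] = flat_image.max()
    return out

x = np.random.rand(784)  # a flattened 28x28 image
x_pos = overlay_label(x, 3)  # positive sample: the correct label is overlaid
x_neg = overlay_label(x, 8)  # negative sample: a wrong label is overlaid
```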

Prediction is broken down into a per-sample prediction function, which the overridden
`predict()` method then maps over the entire test set. For each candidate label, the
image (with that label overlaid) is passed through the network and the 'excitation' of
the neurons in each layer is measured. These per-layer values are summed over all layers
to produce a network-wide 'goodness score'. The label with the highest 'goodness score'
is then chosen as the sample prediction.
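
The scoring loop roughly follows this sketch (hypothetical helper names, reusing the
`overlay_label` sketch above; the real implementation is `predict_one_sample` below):

```python
def predict_by_goodness(flat_image, ff_layers, num_classes=10):
    scores = []
    for label in range(num_classes):
        h = overlay_label(flat_image, label)[None, :]
        per_layer_goodness = []
        for layer in ff_layers:
            h = np.asarray(layer(h))
            per_layer_goodness.append(np.mean(np.square(h)))
        scores.append(np.sum(per_layer_goodness))
    return int(np.argmax(scores))  # label with the highest total goodness
```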

The `train_step` function is overridden to act as the main controlling loop for running
training on each layer, as per the number of epochs per layer.
"""


class FFNetwork(keras.Model):
    """
    A `keras.Model` that supports a `FFDense` network creation. This model
    can work for any kind of classification task. It has an internal
    implementation with some details specific to the MNIST dataset which can be
    changed as per the use-case.
    """

    # Since each layer runs gradient-calculation and optimization locally, each
    # layer has its own optimizer that we pass. As a standard choice, we pass
    # the `Adam` optimizer with a default learning rate of 0.03 as that was
    # found to be the best rate after experimentation.
    # Loss is tracked using `loss_var` and `loss_count` variables.

    def __init__(
        self,
        dims,
        init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.init_layer_optimizer = init_layer_optimizer
        self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
        self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
        self.layer_list = [keras.Input(shape=(dims[0],))]
        self.metrics_built = False
        for d in range(len(dims) - 1):
            self.layer_list += [
                FFDense(
                    dims[d + 1],
                    init_optimizer=self.init_layer_optimizer,
                    loss_metric=keras.metrics.Mean(),
                )
            ]

    # This function makes a dynamic change to the image wherein the labels are
    # put on top of the original image (for this example, as MNIST has 10
    # unique labels, we take the top-left corner's first 10 pixels). This
    # function returns the original data tensor with the first 10 pixels being
    # a pixel-based one-hot representation of the labels.

    @tf.function(reduce_retracing=True)
    def overlay_y_on_x(self, data):
        X_sample, y_sample = data
        max_sample = ops.amax(X_sample, axis=0, keepdims=True)
        max_sample = ops.cast(max_sample, dtype="float64")
        X_zeros = ops.zeros([10], dtype="float64")
        X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
        X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
        return X_sample, y_sample

    # A custom `predict_one_sample` performs predictions by passing the image
    # through the network, measuring the results produced by each layer (i.e.
    # how high/low the output values are with respect to the set threshold for
    # each label) and then simply picking the label with the highest values.
    # In this way, the image is tested for its 'goodness' with every label.

    @tf.function(reduce_retracing=True)
    def predict_one_sample(self, x):
        goodness_per_label = []
        x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
        for label in range(10):
            h, label = self.overlay_y_on_x(data=(x, label))
            h = ops.reshape(h, [-1, ops.shape(h)[0]])
            goodness = []
            for layer_idx in range(1, len(self.layer_list)):
                layer = self.layer_list[layer_idx]
                h = layer(h)
                goodness += [ops.mean(ops.power(h, 2), 1)]
            goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
        goodness_per_label = tf.concat(goodness_per_label, 1)
        return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")

    def predict(self, data):
        x = data
        preds = ops.vectorized_map(self.predict_one_sample, x)
        return np.asarray(preds, dtype=int)

    # This custom `train_step` function overrides the internal `train_step`
    # implementation. We take all the input image tensors, flatten them and
    # subsequently produce positive and negative samples on the images.
    # A positive sample is an image that has the right label encoded on it with
    # the `overlay_y_on_x` function. A negative sample is an image that has an
    # erroneous label present on it.
    # With the samples ready, we pass them through each `FFDense` layer and
    # perform the Forward-Forward computation on it. The returned loss is the
    # final loss value over all the layers.

    @tf.function(jit_compile=False)
    def train_step(self, data):
        x, y = data

        if not self.metrics_built:
            # build metrics to ensure they can be queried without erroring out.
            # We can't update the metrics' state, as we would usually do, since
            # we do not perform predictions within the train step
            for metric in self.metrics:
                if hasattr(metric, "build"):
                    metric.build(y, y)
            self.metrics_built = True

        # Flatten op
        x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])

        x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))

        random_y = tf.random.shuffle(y)
        x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))

        h_pos, h_neg = x_pos, x_neg

        for idx, layer in enumerate(self.layers):
            if isinstance(layer, FFDense):
                print(f"Training layer {idx+1} now : ")
                h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)
                self.loss_var.assign_add(loss)
                self.loss_count.assign_add(1.0)
            else:
                print(f"Passing layer {idx+1} now : ")
                x = layer(x)
        mean_res = ops.divide(self.loss_var, self.loss_count)
        return {"FinalLoss": mean_res}


"""
## Convert MNIST `NumPy` arrays to `tf.data.Dataset`

We now perform some preliminary processing on the `NumPy` arrays and then convert them
into the `tf.data.Dataset` format which allows for optimized loading. Each split is
batched as a single batch (60,000 training and 10,000 test samples), so every model epoch
corresponds to one `train_step` call.
"""

x_train = x_train.astype(float) / 255
x_test = x_test.astype(float) / 255
y_train = y_train.astype(int)
y_test = y_test.astype(int)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

train_dataset = train_dataset.batch(60000)
test_dataset = test_dataset.batch(10000)

"""
## Fit the network and visualize results

With all the set-up done, we now run `model.fit()` for 250 model epochs. Since each
`FFDense` layer runs 50 local epochs per call, this amounts to 50*250 epochs on each
layer. After training, we plot the loss curve over the model epochs.
"""

model = FFNetwork(dims=[784, 500, 500])

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.03),
    loss="mse",
    jit_compile=False,
    metrics=[],
)

epochs = 250
history = model.fit(train_dataset, epochs=epochs)

"""
## Perform inference and testing

Having trained the model, we now see how it performs on the test set.
We calculate the accuracy score to evaluate the results closely.
"""

preds = model.predict(ops.convert_to_tensor(x_test))

preds = preds.reshape((preds.shape[0], preds.shape[1]))

results = accuracy_score(preds, y_test)

print(f"Test Accuracy score : {results*100}%")

plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"])
plt.title("Loss over training")
plt.show()

"""
## Conclusion

This example has demonstrated how the Forward-Forward algorithm works using
the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton
in his paper are currently still limited to smaller models and datasets like MNIST and
Fashion-MNIST, subsequent results on larger models like LLMs are expected in future
papers.

Through the paper, Prof. Hinton has reported a test accuracy error of 1.36% with a
2000-unit, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning
that backpropagation takes only 20 epochs to achieve similar performance). Another run,
doubling the learning rate and training for 40 epochs, yields a slightly worse error rate
of 1.46%.

The current example does not yield state-of-the-art results. But with proper tuning of
the learning rate, model architecture (number of units in `Dense` layers, kernel
activations, initializations, regularization etc.), the results can be improved
to match the claims of the paper.
"""