"""
Title: Gradient Centralization for Better Training Performance
Author: [Rishit Dagli](https://github.com/Rishit-dagli)
Date created: 06/18/21
Last modified: 07/25/23
Description: Implement Gradient Centralization to improve training performance of DNNs.
Accelerator: GPU
Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
"""

"""
## Introduction

This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a
new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it
on Laurence Moroney's [Horses or Humans
Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient
Centralization can both speed up the training process and improve the final generalization
performance of DNNs. It operates directly on gradients by centralizing the gradient
vectors to have zero mean. Gradient Centralization moreover improves the Lipschitzness of
the loss function and its gradient so that the training process becomes more efficient
and stable.

This example requires `tensorflow_datasets`, which can be installed with this command:

```
pip install tensorflow-datasets
```
"""

"""
## Setup
"""

from time import time

import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops

from tensorflow import data as tf_data
import tensorflow_datasets as tfds

"""
## Prepare the data

For this example, we will be using the [Horses or Humans
dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).
"""

num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf_data.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
    name=dataset_name,
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
    with_info=True,
    as_supervised=True,
)

print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")

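"""
If you want to verify the structure returned by `as_supervised=True`, you can peek at a
single example. This is an optional check, not required for the rest of the example:

```python
# Each element is an (image, label) pair; images are 300x300 RGB, labels are 0 or 1.
for image, label in train_ds.take(1):
    print(image.shape, label.numpy())
```
"""
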
"""
71
## Use Data Augmentation
72
73
We will rescale the data to `[0, 1]` and perform simple augmentations to our data.
74
"""
75
76
rescale = layers.Rescaling(1.0 / 255)
77
78
data_augmentation = [
79
layers.RandomFlip("horizontal_and_vertical"),
80
layers.RandomRotation(0.3),
81
layers.RandomZoom(0.2),
82
]
83
84
85
# Helper to apply augmentation
86
def apply_aug(x):
87
for aug in data_augmentation:
88
x = aug(x)
89
return x
90
91
92
def prepare(ds, shuffle=False, augment=False):
93
# Rescale dataset
94
ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
95
96
if shuffle:
97
ds = ds.shuffle(1024)
98
99
# Batch dataset
100
ds = ds.batch(batch_size)
101
102
# Use data augmentation only on the training set
103
if augment:
104
ds = ds.map(
105
lambda x, y: (apply_aug(x), y),
106
num_parallel_calls=AUTOTUNE,
107
)
108
109
# Use buffered prefecting
110
return ds.prefetch(buffer_size=AUTOTUNE)
111
112
113
"""
114
Rescale and augment the data
115
"""
116
117
train_ds = prepare(train_ds, shuffle=True, augment=True)
118
test_ds = prepare(test_ds)
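
"""
As an optional sanity check (not part of the original pipeline), you can pull one batch
from the prepared training set and confirm that the pixel values are rescaled to `[0, 1]`
and that the batch has the expected shape:

```python
images, labels = next(iter(train_ds))
print(images.shape)  # e.g. (128, 300, 300, 3)
print(float(ops.min(images)), float(ops.max(images)))  # values within [0.0, 1.0]
```
"""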
"""
## Define a model

In this section we will define a convolutional neural network.
"""

model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Conv2D(16, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]
)
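
"""
The final `Dense(1, activation="sigmoid")` layer outputs a single probability per image,
which is what the `binary_crossentropy` loss used below expects. A quick way to check the
output shape (optional, illustrative only):

```python
dummy = ops.zeros((1,) + input_shape)
print(model(dummy).shape)  # expected: (1, 1)
```
"""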

"""
## Implement Gradient Centralization

We will now subclass the `RMSprop` optimizer class, modifying the
`keras.optimizers.Optimizer.get_gradients()` method to implement Gradient
Centralization. At a high level, the idea is this: say we obtain the gradients
of a Dense or Convolution layer through backpropagation; we then compute the mean of the
column vectors of the gradient matrix and remove that mean from each column vector.

The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various
applications, including general image classification, fine-grained image classification,
detection and segmentation, and person re-identification, demonstrate that GC can
consistently improve the performance of DNN learning.

Also, for simplicity, we are not implementing gradient clipping functionality at the
moment; however, this is quite easy to add.

For now we are just creating a subclass of the `RMSprop` optimizer,
but you could easily reproduce this for any other optimizer or a custom
optimizer in the same way. We will use this class in a later section when
we train a model with Gradient Centralization.
"""


class GCRMSprop(RMSprop):
    def get_gradients(self, loss, params):
        # We just provide a modified get_gradients() function, since all we
        # need is to compute the centralized gradients.

        grads = []
        gradients = super().get_gradients(loss, params)
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                # Remove the mean over all axes except the last one, so each
                # column vector of the gradient is centered at zero.
                axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads


optimizer = GCRMSprop(learning_rate=1e-4)
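
"""
The same pattern can be applied to other optimizers. As a sketch (hypothetical, following
exactly the structure of `GCRMSprop` above), a Gradient Centralization variant of SGD
could look like this:

```python
class GCSGD(keras.optimizers.SGD):
    def get_gradients(self, loss, params):
        grads = []
        for grad in super().get_gradients(loss, params):
            if len(grad.shape) > 1:
                # Remove the mean over all axes except the last one.
                axis = list(range(len(grad.shape) - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)
        return grads
```
"""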

"""
## Training utilities

We will also create a callback which allows us to easily measure the total training time
and the time taken for each epoch, since we are interested in comparing the effect of
Gradient Centralization on the model we built above.
"""


class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        # Record the wall-clock time at the start of each epoch
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)


"""
## Train the model without GC

We now train the model we built earlier without Gradient Centralization, so that we can
compare it to the training performance of the model trained with Gradient Centralization.
"""

time_callback_no_gc = TimeHistory()
model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=["accuracy"],
)

model.summary()

"""
We also save the history, since we later want to compare the model trained with and
without Gradient Centralization.
"""

history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)

"""
## Train the model with GC

We will now train the same model, this time using Gradient Centralization;
notice that our optimizer is the one using Gradient Centralization this time.
"""

time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])

"""
## Comparing performance
"""

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")

"""
Readers are encouraged to try out Gradient Centralization on different datasets from
different domains and experiment with its effect. You are strongly advised to check out
the [original paper](https://arxiv.org/abs/2004.01461) as well - the authors present
several studies on Gradient Centralization showing how it can improve overall
performance, generalization, and training efficiency.

Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this
implementation.
"""