"""1Title: Gradient Centralization for Better Training Performance2Author: [Rishit Dagli](https://github.com/Rishit-dagli)3Date created: 06/18/214Last modified: 07/25/235Description: Implement Gradient Centralization to improve training performance of DNNs.6Accelerator: GPU7Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)8"""910"""11## Introduction1213This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a14new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it15on Laurence Moroney's [Horses or Humans16Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient17Centralization can both speedup training process and improve the final generalization18performance of DNNs. It operates directly on gradients by centralizing the gradient19vectors to have zero mean. Gradient Centralization morever improves the Lipschitzness of20the loss function and its gradient so that the training process becomes more efficient21and stable.2223This example requires `tensorflow_datasets` which can be installed with this command:2425```26pip install tensorflow-datasets27```28"""2930"""31## Setup32"""3334from time import time3536import keras37from keras import layers38from keras.optimizers import RMSprop39from keras import ops4041from tensorflow import data as tf_data42import tensorflow_datasets as tfds434445"""46## Prepare the data4748For this example, we will be using the [Horses or Humans49dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).50"""5152num_classes = 253input_shape = (300, 300, 3)54dataset_name = "horses_or_humans"55batch_size = 12856AUTOTUNE = tf_data.AUTOTUNE5758(train_ds, test_ds), metadata = tfds.load(59name=dataset_name,60split=[tfds.Split.TRAIN, tfds.Split.TEST],61with_info=True,62as_supervised=True,63)6465print(f"Image shape: {metadata.features['image'].shape}")66print(f"Training images: {metadata.splits['train'].num_examples}")67print(f"Test images: {metadata.splits['test'].num_examples}")6869"""70## Use Data Augmentation7172We will rescale the data to `[0, 1]` and perform simple augmentations to our data.73"""7475rescale = layers.Rescaling(1.0 / 255)7677data_augmentation = [78layers.RandomFlip("horizontal_and_vertical"),79layers.RandomRotation(0.3),80layers.RandomZoom(0.2),81]828384# Helper to apply augmentation85def apply_aug(x):86for aug in data_augmentation:87x = aug(x)88return x899091def prepare(ds, shuffle=False, augment=False):92# Rescale dataset93ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)9495if shuffle:96ds = ds.shuffle(1024)9798# Batch dataset99ds = ds.batch(batch_size)100101# Use data augmentation only on the training set102if augment:103ds = ds.map(104lambda x, y: (apply_aug(x), y),105num_parallel_calls=AUTOTUNE,106)107108# Use buffered prefecting109return ds.prefetch(buffer_size=AUTOTUNE)110111112"""113Rescale and augment the data114"""115116train_ds = prepare(train_ds, shuffle=True, augment=True)117test_ds = prepare(test_ds)118"""119## Define a model120121In this section we will define a Convolutional neural network.122"""123124model = keras.Sequential(125[126layers.Input(shape=input_shape),127layers.Conv2D(16, (3, 3), activation="relu"),128layers.MaxPooling2D(2, 2),129layers.Conv2D(32, (3, 3), activation="relu"),130layers.Dropout(0.5),131layers.MaxPooling2D(2, 2),132layers.Conv2D(64, (3, 3), activation="relu"),133layers.Dropout(0.5),134layers.MaxPooling2D(2, 2),135layers.Conv2D(64, (3, 3), 
activation="relu"),136layers.MaxPooling2D(2, 2),137layers.Conv2D(64, (3, 3), activation="relu"),138layers.MaxPooling2D(2, 2),139layers.Flatten(),140layers.Dropout(0.5),141layers.Dense(512, activation="relu"),142layers.Dense(1, activation="sigmoid"),143]144)145146"""147## Implement Gradient Centralization148149We will now150subclass the `RMSProp` optimizer class modifying the151`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient152Centralization. On a high level the idea is that let us say we obtain our gradients153through back propagation for a Dense or Convolution layer we then compute the mean of the154column vectors of the weight matrix, and then remove the mean from each column vector.155156The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various157applications, including general image classification, fine-grained image classification,158detection and segmentation and Person ReID demonstrate that GC can consistently improve159the performance of DNN learning.160161Also, for simplicity at the moment we are not implementing gradient cliiping functionality,162however this quite easy to implement.163164At the moment we are just creating a subclass for the `RMSProp` optimizer165however you could easily reproduce this for any other optimizer or on a custom166optimizer in the same way. We will be using this class in the later section when167we train a model with Gradient Centralization.168"""169170171class GCRMSprop(RMSprop):172def get_gradients(self, loss, params):173# We here just provide a modified get_gradients() function since we are174# trying to just compute the centralized gradients.175176grads = []177gradients = super().get_gradients()178for grad in gradients:179grad_len = len(grad.shape)180if grad_len > 1:181axis = list(range(grad_len - 1))182grad -= ops.mean(grad, axis=axis, keep_dims=True)183grads.append(grad)184185return grads186187188optimizer = GCRMSprop(learning_rate=1e-4)189190"""191## Training utilities192193We will also create a callback which allows us to easily measure the total training time194and the time taken for each epoch since we are interested in comparing the effect of195Gradient Centralization on the model we built above.196"""197198199class TimeHistory(keras.callbacks.Callback):200def on_train_begin(self, logs={}):201self.times = []202203def on_epoch_begin(self, batch, logs={}):204self.epoch_time_start = time()205206def on_epoch_end(self, batch, logs={}):207self.times.append(time() - self.epoch_time_start)208209210"""211## Train the model without GC212213We now train the model we built earlier without Gradient Centralization which we can214compare to the training performance of the model trained with Gradient Centralization.215"""216217time_callback_no_gc = TimeHistory()218model.compile(219loss="binary_crossentropy",220optimizer=RMSprop(learning_rate=1e-4),221metrics=["accuracy"],222)223224model.summary()225226"""227We also save the history since we later want to compare our model trained with and not228trained with Gradient Centralization229"""230231history_no_gc = model.fit(232train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]233)234235"""236## Train the model with GC237238We will now train the same model, this time using Gradient Centralization,239notice our optimizer is the one using Gradient Centralization this time.240"""241242time_callback_gc = TimeHistory()243model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])244245model.summary()246247history_gc = 

"""
## Training utilities

We will also create a callback which allows us to easily measure the total training time
and the time taken for each epoch, since we are interested in comparing the effect of
Gradient Centralization on the model we built above.
"""


class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)


"""
## Train the model without GC

We now train the model we built earlier without Gradient Centralization, which we can
compare to the training performance of the model trained with Gradient Centralization.
"""

time_callback_no_gc = TimeHistory()
model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=["accuracy"],
)

model.summary()

"""
We also save the history, since we later want to compare our model trained with and
without Gradient Centralization.
"""

history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)

"""
## Train the model with GC

We will now train the same model, this time using Gradient Centralization. Notice that
our optimizer is the one using Gradient Centralization this time.
"""

time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])

"""
## Comparing performance
"""

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)} seconds")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)} seconds")

"""
Readers are encouraged to try out Gradient Centralization on different datasets from
different domains and experiment with its effect. You are strongly advised to check out
the [original paper](https://arxiv.org/abs/2004.01461) as well - the authors present
several studies on Gradient Centralization showing how it can improve general
performance, generalization, and training time, and make training more efficient.

Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this
implementation.
"""
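
"""
As a starting point for such experiments, the override above can be reproduced for any
other built-in or custom optimizer in exactly the same way. Below is a minimal sketch of
this (added for illustration, using `keras.optimizers.Adam` as an arbitrary choice):
"""


class GCAdam(keras.optimizers.Adam):
    def get_gradients(self, loss, params):
        # Same centralization step as in GCRMSprop above: subtract the mean over
        # all axes except the last from every gradient of rank greater than 1.
        grads = []
        for grad in super().get_gradients(loss, params):
            if len(grad.shape) > 1:
                axis = list(range(len(grad.shape) - 1))
                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)
        return grads


optimizer_gc_adam = GCAdam(learning_rate=1e-4)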