"""
Title: Consistency training with supervision
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/04/13
Last modified: 2021/04/19
Description: Training with consistency regularization for robustness against data distribution shifts.
Accelerator: GPU
"""

"""
11
Deep learning models excel in many image recognition tasks when the data is independent
12
and identically distributed (i.i.d.). However, they can suffer from performance
13
degradation caused by subtle distribution shifts in the input data (such as random
14
noise, contrast change, and blurring). So, naturally, there arises a question of
15
why. As discussed in [A Fourier Perspective on Model Robustness in Computer Vision](https://arxiv.org/pdf/1906.08988.pdf)),
16
there's no reason for deep learning models to be robust against such shifts. Standard
17
model training procedures (such as standard image classification training workflows)
18
*don't* enable a model to learn beyond what's fed to it in the form of training data.
19
20
In this example, we will be training an image classification model enforcing a sense of
21
*consistency* inside it by doing the following:
22
23
* Train a standard image classification model.
24
* Train an _equal or larger_ model on a noisy version of the dataset (augmented using
25
[RandAugment](https://arxiv.org/abs/1909.13719)).
26
* To do this, we will first obtain predictions of the previous model on the clean images
27
of the dataset.
28
* We will then use these predictions and train the second model to match these
29
predictions on the noisy variant of the same images. This is identical to the workflow of
30
[*Knowledge Distillation*](https://keras.io/examples/vision/knowledge_distillation/) but
31
since the student model is equal or larger in size this process is also referred to as
32
***Self-Training***.
33
34
This overall training workflow finds its roots in works like
35
[FixMatch](https://arxiv.org/abs/2001.07685), [Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848),
36
and [Noisy Student Training](https://arxiv.org/abs/1911.04252). Since this training
37
process encourages a model yield consistent predictions for clean as well as noisy
38
images, it's often referred to as *consistency training* or *training with consistency
39
regularization*. Although the example focuses on using consistency training to enhance
40
the robustness of models to common corruptions this example can also serve a template
41
for performing _weakly supervised learning_.
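
Concretely, the student model in this example is trained to minimize an equally weighted
combination of the regular cross-entropy loss on the noisy images and a
temperature-scaled distillation loss against the teacher's predictions on the
corresponding clean images. Using notation introduced here purely for illustration
($f_t$: teacher, $f_s$: student, $\sigma$: softmax, $T$: temperature), the objective
computed by the `SelfTrainer` class later in this example can roughly be written as:

$$
\mathcal{L} = \frac{1}{2}\left[\mathrm{CE}\big(y, f_s(x_{\text{noisy}})\big) +
\mathrm{KL}\big(\sigma(f_t(x_{\text{clean}})/T) \,\|\, \sigma(f_s(x_{\text{noisy}})/T)\big)\right]
$$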

This example requires TensorFlow 2.4 or higher, as well as TensorFlow Models and
TensorFlow Addons, which can be installed using the following command:
"""
"""shell
49
pip install -q tf-models-official tensorflow-addons
50
"""
51
52
"""
53
## Imports and setup
54
"""
55
56
from official.vision.image_classification.augment import RandAugment
57
from tensorflow.keras import layers
58
59
import tensorflow as tf
60
import tensorflow_addons as tfa
61
import matplotlib.pyplot as plt
62
63
tf.random.set_seed(42)
64
65
"""
66
## Define hyperparameters
67
"""
68
69
AUTO = tf.data.AUTOTUNE
70
BATCH_SIZE = 128
71
EPOCHS = 5
72
73
CROP_TO = 72
74
RESIZE_TO = 96
75
76
"""
77
## Load the CIFAR-10 dataset
78
"""
79
80
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
81
82
val_samples = 49500
83
new_train_x, new_y_train = x_train[: val_samples + 1], y_train[: val_samples + 1]
84
val_x, val_y = x_train[val_samples:], y_train[val_samples:]
85
86
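"""
As a quick, purely optional sanity check, you can confirm the sizes of the two splits:

```python
print(f"Training samples: {len(new_train_x)}")    # 49500
print(f"Validation samples: {len(val_x)}")        # 500
```
"""
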
"""
87
## Create TensorFlow `Dataset` objects
88
"""
89
90
# Initialize `RandAugment` object with 2 layers of
91
# augmentation transforms and strength of 9.
92
augmenter = RandAugment(num_layers=2, magnitude=9)
93
94
"""
95
For training the teacher model, we will only be using two geometric augmentation
96
transforms: random horizontal flip and random crop.
97
"""
98
99
100
def preprocess_train(image, label, noisy=True):
    image = tf.image.random_flip_left_right(image)
    # We first resize the original image to a larger dimension
    # and then we take random crops from it.
    image = tf.image.resize(image, [RESIZE_TO, RESIZE_TO])
    image = tf.image.random_crop(image, [CROP_TO, CROP_TO, 3])
    if noisy:
        image = augmenter.distort(image)
    return image, label


def preprocess_test(image, label):
    image = tf.image.resize(image, [CROP_TO, CROP_TO])
    return image, label


train_ds = tf.data.Dataset.from_tensor_slices((new_train_x, new_y_train))
validation_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

"""
121
We make sure `train_clean_ds` and `train_noisy_ds` are shuffled using the *same* seed to
122
ensure their orders are exactly the same. This will be helpful during training the
123
student model.
124
"""
125
126
# This dataset will be used to train the first model.
127
train_clean_ds = (
128
train_ds.shuffle(BATCH_SIZE * 10, seed=42)
129
.map(lambda x, y: (preprocess_train(x, y, noisy=False)), num_parallel_calls=AUTO)
130
.batch(BATCH_SIZE)
131
.prefetch(AUTO)
132
)
133
134
# This prepares the `Dataset` object to use RandAugment.
135
train_noisy_ds = (
136
train_ds.shuffle(BATCH_SIZE * 10, seed=42)
137
.map(preprocess_train, num_parallel_calls=AUTO)
138
.batch(BATCH_SIZE)
139
.prefetch(AUTO)
140
)
141
142
validation_ds = (
143
validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
144
.batch(BATCH_SIZE)
145
.prefetch(AUTO)
146
)
147
148
test_ds = (
149
test_ds.map(preprocess_test, num_parallel_calls=AUTO)
150
.batch(BATCH_SIZE)
151
.prefetch(AUTO)
152
)
153
154
# This dataset will be used to train the second model.
155
consistency_training_ds = tf.data.Dataset.zip((train_clean_ds, train_noisy_ds))
156
157
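"""
Because both streams are shuffled with the same seed, zipping them keeps the clean and
noisy views of each image (and their labels) aligned batch-by-batch. A quick way to
convince yourself of this, not executed as part of this example, is to compare the
labels coming out of one batch of the zipped dataset:

```python
(clean_images, clean_labels), (noisy_images, noisy_labels) = next(
    iter(consistency_training_ds)
)
tf.debugging.assert_equal(clean_labels, noisy_labels)
```
"""
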
"""
158
## Visualize the datasets
159
"""
160
161
sample_images, sample_labels = next(iter(train_clean_ds))
162
plt.figure(figsize=(10, 10))
163
for i, image in enumerate(sample_images[:9]):
164
ax = plt.subplot(3, 3, i + 1)
165
plt.imshow(image.numpy().astype("int"))
166
plt.axis("off")
167
168
sample_images, sample_labels = next(iter(train_noisy_ds))
169
plt.figure(figsize=(10, 10))
170
for i, image in enumerate(sample_images[:9]):
171
ax = plt.subplot(3, 3, i + 1)
172
plt.imshow(image.numpy().astype("int"))
173
plt.axis("off")
174
175
"""
176
## Define a model building utility function
177
178
We now define our model building utility. Our model is based on the [ResNet50V2 architecture](https://arxiv.org/abs/1603.05027).
179
"""
180
181
182
def get_training_model(num_classes=10):
183
resnet50_v2 = tf.keras.applications.ResNet50V2(
184
weights=None,
185
include_top=False,
186
input_shape=(CROP_TO, CROP_TO, 3),
187
)
188
model = tf.keras.Sequential(
189
[
190
layers.Input((CROP_TO, CROP_TO, 3)),
191
layers.Rescaling(scale=1.0 / 127.5, offset=-1),
192
resnet50_v2,
193
layers.GlobalAveragePooling2D(),
194
layers.Dense(num_classes),
195
]
196
)
197
return model
198
199
200
"""
201
In the interest of reproducibility, we serialize the initial random weights of the
202
teacher network.
203
"""
204
205
initial_teacher_model = get_training_model()
206
initial_teacher_model.save_weights("initial_teacher_model.h5")
207
208
"""
209
## Train the teacher model
210
211
As noted in Noisy Student Training, if the teacher model is trained with *geometric
212
ensembling* and when the student model is forced to mimic that, it leads to better
213
performance. The original work uses [Stochastic Depth](https://arxiv.org/abs/1603.09382)
214
and [Dropout](https://jmlr.org/papers/v15/srivastava14a.html) to bring in the ensembling
215
part but for this example, we will use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407)
216
(SWA) which also resembles geometric ensembling.
217
"""
218
219
# Define the callbacks.
220
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=3)
221
early_stopping = tf.keras.callbacks.EarlyStopping(
222
patience=10, restore_best_weights=True
223
)
224
225
# Initialize SWA from tf-hub.
226
SWA = tfa.optimizers.SWA
227
228
# Compile and train the teacher model.
229
teacher_model = get_training_model()
230
teacher_model.load_weights("initial_teacher_model.h5")
231
teacher_model.compile(
232
# Notice that we are wrapping our optimizer within SWA
233
optimizer=SWA(tf.keras.optimizers.Adam()),
234
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
235
metrics=["accuracy"],
236
)
237
history = teacher_model.fit(
238
train_clean_ds,
239
epochs=EPOCHS,
240
validation_data=validation_ds,
241
callbacks=[reduce_lr, early_stopping],
242
)
243
244
# Evaluate the teacher model on the test set.
245
_, acc = teacher_model.evaluate(test_ds, verbose=0)
246
print(f"Test accuracy: {acc*100}%")
247
248
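"""
Note that we wrapped the optimizer with `tfa.optimizers.SWA` using its default
arguments. If you train for longer than the few epochs used here, you may want to delay
the start of weight averaging until the model has partially converged. A possible
configuration (the numbers below are illustrative, not tuned) would be:

```python
# Start averaging after 600 optimizer steps and update the running average every 10 steps.
optimizer = tfa.optimizers.SWA(
    tf.keras.optimizers.Adam(), start_averaging=600, average_period=10
)
```
"""
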
"""
249
## Define a self-training utility
250
251
For this part, we will borrow the `Distiller` class from [this Keras Example](https://keras.io/examples/vision/knowledge_distillation/).
252
"""
253
254
255
# Majority of the code is taken from:
256
# https://keras.io/examples/vision/knowledge_distillation/
257
class SelfTrainer(tf.keras.Model):
258
def __init__(self, student, teacher):
259
super().__init__()
260
self.student = student
261
self.teacher = teacher
262
263
def compile(
264
self,
265
optimizer,
266
metrics,
267
student_loss_fn,
268
distillation_loss_fn,
269
temperature=3,
270
):
271
super().compile(optimizer=optimizer, metrics=metrics)
272
self.student_loss_fn = student_loss_fn
273
self.distillation_loss_fn = distillation_loss_fn
274
self.temperature = temperature
275
276
def train_step(self, data):
277
# Since our dataset is a zip of two independent datasets,
278
# after initially parsing them, we segregate the
279
# respective images and labels next.
280
clean_ds, noisy_ds = data
281
clean_images, _ = clean_ds
282
noisy_images, y = noisy_ds
283
284
# Forward pass of teacher
285
teacher_predictions = self.teacher(clean_images, training=False)
286
287
with tf.GradientTape() as tape:
288
# Forward pass of student
289
student_predictions = self.student(noisy_images, training=True)
290
291
# Compute losses
292
student_loss = self.student_loss_fn(y, student_predictions)
293
distillation_loss = self.distillation_loss_fn(
294
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
295
tf.nn.softmax(student_predictions / self.temperature, axis=1),
296
)
297
total_loss = (student_loss + distillation_loss) / 2
298
299
# Compute gradients
300
trainable_vars = self.student.trainable_variables
301
gradients = tape.gradient(total_loss, trainable_vars)
302
303
# Update weights
304
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
305
306
# Update the metrics configured in `compile()`
307
self.compiled_metrics.update_state(
308
y, tf.nn.softmax(student_predictions, axis=1)
309
)
310
311
# Return a dict of performance
312
results = {m.name: m.result() for m in self.metrics}
313
results.update({"total_loss": total_loss})
314
return results
315
316
def test_step(self, data):
317
# During inference, we only pass a dataset consisting images and labels.
318
x, y = data
319
320
# Compute predictions
321
y_prediction = self.student(x, training=False)
322
323
# Update the metrics
324
self.compiled_metrics.update_state(y, tf.nn.softmax(y_prediction, axis=1))
325
326
# Return a dict of performance
327
results = {m.name: m.result() for m in self.metrics}
328
return results
329
330
331
"""
332
The only difference in this implementation is the way loss is being calculated. **Instead
333
of weighted the distillation loss and student loss differently we are taking their
334
average following Noisy Student Training**.
335
"""
336
337
"""
338
## Train the student model
339
"""
340
341
# Define the callbacks.
# We are using a larger `factor` (0.5 instead of the default 0.1) so the learning
# rate decays more gently, which helps stabilize the training.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    patience=3, factor=0.5, monitor="val_accuracy"
)
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True, monitor="val_accuracy"
)

# Compile and train the student model.
self_trainer = SelfTrainer(student=get_training_model(), teacher=teacher_model)
self_trainer.compile(
    # Notice we are *not* using SWA here.
    optimizer="adam",
    metrics=["accuracy"],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    temperature=10,
)
history = self_trainer.fit(
    consistency_training_ds,
    epochs=EPOCHS,
    validation_data=validation_ds,
    callbacks=[reduce_lr, early_stopping],
)

# Evaluate the student model.
acc = self_trainer.evaluate(test_ds, verbose=0)
print(f"Test accuracy from student model: {acc * 100}%")

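"""
The `SelfTrainer` wrapper is only needed for training; the trained network itself lives
in its `student` attribute. If you want to reuse or deploy the distilled model, you can
extract and serialize just the student, for example:

```python
student_model = self_trainer.student
student_model.save_weights("student_model.h5")
```
"""
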
"""
372
## Assess the robustness of the models
373
374
A standard benchmark of assessing the robustness of vision models is to record their
375
performance on corrupted datasets like ImageNet-C and CIFAR-10-C both of which were
376
proposed in [Benchmarking Neural Network Robustness to Common Corruptions and
377
Perturbations](https://arxiv.org/abs/1903.12261). For this example, we will be using the
378
CIFAR-10-C dataset which has 19 different corruptions on 5 different severity levels. To
379
assess the robustness of the models on this dataset, we will do the following:
380
381
* Run the pre-trained models on the highest level of severities and obtain the top-1
382
accuracies.
383
* Compute the mean top-1 accuracy.
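
As a reference for how such an evaluation could be wired up, here is a minimal sketch,
not run in this example. It assumes the `cifar10_corrupted` builds available in
TensorFlow Datasets, where each configuration is named `"{corruption}_{severity}"`:

```python
import tensorflow_datasets as tfds


def evaluate_on_corruption(model, corruption, severity=5):
    # Load one corruption type at the given severity level and reuse the
    # same preprocessing/batching as the clean test set.
    corrupted_ds = tfds.load(
        f"cifar10_corrupted/{corruption}_{severity}",
        split="test",
        as_supervised=True,
    )
    corrupted_ds = (
        corrupted_ds.map(preprocess_test, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )
    return model.evaluate(corrupted_ds, verbose=0)
```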

For the purpose of this example, we won't be going through these steps. This is why we
trained the models for only 5 epochs. You can check out
[this repository](https://github.com/sayakpaul/Consistency-Training-with-Supervision) that
demonstrates the full-scale training experiments and also the aforementioned assessment.
The figure below presents an executive summary of that assessment:

![](https://i.ibb.co/HBJkM9R/image.png)

**Mean Top-1** results stand for the CIFAR-10-C dataset and **Test Top-1** results stand
for the CIFAR-10 test set. It's clear that consistency training not only enhances
model robustness but also improves standard test performance.
"""