"""
Title: FixRes: Fixing train-test resolution discrepancy
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/10/08
Last modified: 2021/10/10
Description: Mitigating resolution discrepancy between training and test sets.
Accelerator: GPU
"""

"""
## Introduction

It is a common practice to use the same input image resolution while training and testing
vision models. However, as investigated in
[Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423)
(Touvron et al.), this practice leads to suboptimal performance. Data augmentation
is an indispensable part of the training process of deep neural networks. For vision models, we
typically use random resized crops during training and center crops during inference.
This introduces a discrepancy in the object sizes seen during training and inference.
As shown by Touvron et al., if we can fix this discrepancy, we can significantly
boost model performance.

In this example, we implement the **FixRes** techniques introduced by Touvron et al.
to fix this discrepancy.
"""

"""
## Imports
"""

import keras
from keras import layers
import tensorflow as tf  # just for image processing and pipeline

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

import matplotlib.pyplot as plt

"""
## Load the `tf_flowers` dataset
"""

train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
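
# Note (added commentary; assuming the standard `tf_flowers` release of 3,670
# examples): this 90/10 split yields roughly 3,303 training and 367 validation
# examples, which the two prints above should confirm.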

"""
## Data preprocessing utilities
"""

"""
We create three datasets:

1. A dataset with a smaller resolution - 128x128.
2. Two datasets with a larger resolution - 224x224.

We will apply different augmentation transforms to the larger-resolution datasets.

The idea of FixRes is to first train a model on a smaller resolution dataset and then fine-tune
it on a larger resolution dataset. This simple yet effective recipe leads to non-trivial performance
improvements. Please refer to the [original paper](https://arxiv.org/abs/1906.06423) for
results.
"""

# Reference: https://github.com/facebookresearch/FixRes/blob/main/transforms_v2.py.

batch_size = 32
auto = tf.data.AUTOTUNE
smaller_size = 128
bigger_size = 224

# The fine-tuning inputs are produced by resizing to a proportionally larger
# size and then center-cropping to `bigger_size`: int((224 / 128) * 224) = 392.
size_for_resizing = int((bigger_size / smaller_size) * bigger_size)
central_crop_layer = layers.CenterCrop(bigger_size, bigger_size)


def preprocess_initial(train, image_size):
    """Initial preprocessing function for training on smaller resolution.

    For training, do a random resized crop -> random_horizontal_flip.
    For validation, just resize.
    No color-jittering has been used.
    """

    def _pp(image, label, train):
        if train:
            channels = image.shape[-1]
            # Sample a random crop covering 5-100% of the image area.
            begin, size, _ = tf.image.sample_distorted_bounding_box(
                tf.shape(image),
                tf.zeros([0, 0, 4], tf.float32),
                area_range=(0.05, 1.0),
                min_object_covered=0,
                use_image_if_no_bounding_boxes=True,
            )
            image = tf.slice(image, begin, size)

            image.set_shape([None, None, channels])
            image = tf.image.resize(image, [image_size, image_size])
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, [image_size, image_size])

        return image, label

    return _pp
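

# Quick sanity check (illustrative; not part of the original example): apply
# the training-time closure to a dummy image and confirm the output resolution.
_check_pp = preprocess_initial(train=True, image_size=smaller_size)
_dummy = tf.random.uniform((300, 400, 3), maxval=255, dtype=tf.float32)
_out, _ = _check_pp(_dummy, tf.constant(0), True)
print(_out.shape)  # Expected: (128, 128, 3)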


def preprocess_finetune(image, label, train):
    """Preprocessing function for fine-tuning on a higher resolution.

    For training, resize to a bigger resolution to maintain the ratio ->
    random_horizontal_flip -> center_crop.
    For validation, do the same without any horizontal flipping.
    No color-jittering has been used.
    """
    image = tf.image.resize(image, [size_for_resizing, size_for_resizing])
    if train:
        image = tf.image.random_flip_left_right(image)
    # The CenterCrop layer expects a batch dimension, so add and remove one.
    image = central_crop_layer(image[None, ...])[0]

    return image, label
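

# Quick shape check (illustrative; not part of the original example): the
# fine-tuning path resizes to 392x392 and then center-crops to 224x224.
_dummy_ft = tf.random.uniform((300, 400, 3), maxval=255, dtype=tf.float32)
_out_ft, _ = preprocess_finetune(_dummy_ft, tf.constant(0), train=False)
print(_out_ft.shape)  # Expected: (224, 224, 3)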


def make_dataset(
    dataset: tf.data.Dataset,
    train: bool,
    image_size: int = smaller_size,
    fixres: bool = True,
    num_parallel_calls=auto,
):
    if image_size not in [smaller_size, bigger_size]:
        raise ValueError(f"{image_size} resolution is not supported.")

    # Determine which preprocessing function we are using.
    if image_size == smaller_size:
        preprocess_func = preprocess_initial(train, image_size)
    elif not fixres and image_size == bigger_size:
        preprocess_func = preprocess_initial(train, image_size)
    else:
        preprocess_func = preprocess_finetune

    dataset = dataset.map(
        lambda x, y: preprocess_func(x, y, train),
        num_parallel_calls=num_parallel_calls,
    )
    dataset = dataset.batch(batch_size)

    if train:
        dataset = dataset.shuffle(batch_size * 10)

    return dataset.prefetch(num_parallel_calls)


"""
Notice how the augmentation transforms vary depending on the kind of dataset we are preparing.
"""

"""
## Prepare datasets
"""

initial_train_dataset = make_dataset(train_dataset, train=True, image_size=smaller_size)
initial_val_dataset = make_dataset(val_dataset, train=False, image_size=smaller_size)

finetune_train_dataset = make_dataset(train_dataset, train=True, image_size=bigger_size)
finetune_val_dataset = make_dataset(val_dataset, train=False, image_size=bigger_size)

vanilla_train_dataset = make_dataset(
    train_dataset, train=True, image_size=bigger_size, fixres=False
)
vanilla_val_dataset = make_dataset(
    val_dataset, train=False, image_size=bigger_size, fixres=False
)
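
# Sanity check (illustrative; not in the original example): each pipeline
# should yield image batches at the expected resolution.
for _name, _ds in [
    ("initial", initial_train_dataset),
    ("finetune", finetune_train_dataset),
    ("vanilla", vanilla_train_dataset),
]:
    print(_name, _ds.element_spec[0].shape)
# Expected: initial (None, 128, 128, 3); finetune and vanilla (None, 224, 224, 3).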

"""
## Visualize the datasets
"""


def visualize_dataset(batch_images):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(batch_images[n].numpy().astype("int"))
        plt.axis("off")
    plt.show()

    print(f"Batch shape: {batch_images.shape}.")


# Smaller resolution.
initial_sample_images, _ = next(iter(initial_train_dataset))
visualize_dataset(initial_sample_images)

# Bigger resolution, only for fine-tuning.
finetune_sample_images, _ = next(iter(finetune_train_dataset))
visualize_dataset(finetune_sample_images)

# Bigger resolution, with the same augmentation transforms as
# the smaller resolution dataset.
vanilla_sample_images, _ = next(iter(vanilla_train_dataset))
visualize_dataset(vanilla_sample_images)
"""
211
## Model training utilities
212
213
We train multiple variants of ResNet50V2
214
([He et al.](https://arxiv.org/abs/1603.05027)):
215
216
1. On the smaller resolution dataset (128x128). It will be trained from scratch.
217
2. Then fine-tune the model from 1 on the larger resolution (224x224) dataset.
218
3. Train another ResNet50V2 from scratch on the larger resolution dataset.
219
220
As a reminder, the larger resolution datasets differ in terms of their augmentation
221
transforms.
222
"""


def get_training_model(num_classes=5):
    inputs = layers.Input((None, None, 3))
    resnet_base = keras.applications.ResNet50V2(
        include_top=False, weights=None, pooling="avg"
    )
    resnet_base.trainable = True

    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
    x = resnet_base(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
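

# Note (added commentary, not in the original example): the spatial dimensions
# of the input are left as `None`, so the very same model instance can be
# trained on 128x128 batches and later fine-tuned on 224x224 batches without
# any architectural change.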


def train_and_evaluate(
    model,
    train_ds,
    val_ds,
    epochs,
    learning_rate=1e-3,
    use_early_stopping=False,
):
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    if use_early_stopping:
        es_callback = keras.callbacks.EarlyStopping(patience=5)
        callbacks = [es_callback]
    else:
        callbacks = None

    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
    )

    _, accuracy = model.evaluate(val_ds)
    print(f"Top-1 accuracy on the validation set: {accuracy*100:.2f}%.")
    return model


"""
## Experiment 1: Train on 128x128 and then fine-tune on 224x224
"""

epochs = 30

smaller_res_model = get_training_model()
smaller_res_model = train_and_evaluate(
    smaller_res_model, initial_train_dataset, initial_val_dataset, epochs
)

"""
### Freeze all the layers except for the final Batch Normalization layer

For fine-tuning, we train only two layers:

* The final Batch Normalization ([Ioffe et al.](https://arxiv.org/abs/1502.03167)) layer.
* The classification layer.

We are unfreezing the final Batch Normalization layer to compensate for the change in
activation statistics before the global average pooling layer. As shown in
[the paper](https://arxiv.org/abs/1906.06423), unfreezing the final Batch
Normalization layer is enough.

For a comprehensive guide on fine-tuning models in Keras, refer to
[this tutorial](https://keras.io/guides/transfer_learning/).
"""

# `layers[2]` is the ResNet50V2 base (inputs -> Rescaling -> ResNet50V2 -> Dense).
for layer in smaller_res_model.layers[2].layers:
    layer.trainable = False

smaller_res_model.layers[2].get_layer("post_bn").trainable = True
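
# Sanity check (illustrative; not in the original example): only the final
# BatchNorm parameters and the classification head should remain trainable.
print(f"Trainable weight tensors: {len(smaller_res_model.trainable_weights)}")
# Expected: 4 (post_bn gamma and beta, plus the Dense kernel and bias).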

epochs = 10

# Use a lower learning rate during fine-tuning.
bigger_res_model = train_and_evaluate(
    smaller_res_model,
    finetune_train_dataset,
    finetune_val_dataset,
    epochs,
    learning_rate=1e-4,
)

"""
## Experiment 2: Train a model on 224x224 resolution from scratch

Now, we train another model from scratch on the larger resolution dataset. Recall that
the augmentation transforms used in this dataset are different from before.
"""

epochs = 30

vanilla_bigger_res_model = get_training_model()
vanilla_bigger_res_model = train_and_evaluate(
    vanilla_bigger_res_model, vanilla_train_dataset, vanilla_val_dataset, epochs
)

"""
As we can see from the cells above, FixRes leads to better performance. Another
advantage of FixRes is reduced total training time and GPU memory usage.
FixRes is model-agnostic: you can use it with any image classification model
to potentially boost performance.

You can find more results
[here](https://tensorboard.dev/experiment/BQOg28w0TlmvuJYeqsVntw)
that were gathered by running the same code with different random seeds.
"""