"""
2
Title: CutMix data augmentation for image classification
3
Author: [Sayan Nath](https://twitter.com/sayannath2350)
4
Date created: 2021/06/08
5
Last modified: 2023/11/14
6
Description: Data augmentation with CutMix for image classification on CIFAR-10.
7
Accelerator: GPU
8
Converted to Keras 3 By: [Piyush Thakur](https://github.com/cosmo3769)
9
"""
10
11
"""
12
## Introduction
13
"""
14
"""
_CutMix_ is a data augmentation technique that addresses the issue of information loss
and inefficiency present in regional dropout strategies.
Instead of removing pixels and filling them with black or grey pixels or Gaussian noise,
you replace the removed regions with a patch from another image,
while the ground truth labels are mixed proportionally to the number of pixels of the combined images.
CutMix was proposed in
[CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features](https://arxiv.org/abs/1905.04899)
(Yun et al., 2019).

It's implemented via the following formulas:

<img src="https://i.imgur.com/cGvd13V.png" width="200"/>

where `M` is the binary mask which indicates the cutout and the fill-in
regions from the two randomly drawn images, and `λ` (in `[0, 1]`) is drawn from a
[`Beta(α, α)` distribution](https://en.wikipedia.org/wiki/Beta_distribution).

The coordinates of the bounding boxes are:

<img src="https://i.imgur.com/eNisep4.png" width="150"/>

which indicate the cutout and fill-in regions of the two images.
The bounding box sampling is represented by:

<img src="https://i.imgur.com/Snph9aj.png" width="200"/>

where `rx, ry` are drawn from a uniform distribution with the image width and
height as upper bounds.
"""

"""
## Setup
"""

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# TF imports related to tf.data preprocessing
from tensorflow import clip_by_value
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import random as tf_random

keras.utils.set_random_seed(42)

"""
## Load the CIFAR-10 dataset

In this example, we will use the
[CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

class_names = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]

"""
## Define hyperparameters
"""

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 32
IMG_SIZE = 32

"""
## Define the image preprocessing function
"""


def preprocess_image(image, label):
    # `resize` always returns a float32 image, so `convert_image_dtype` is a
    # no-op here; the division by 255 rescales the pixels into [0, 1].
    image = tf_image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf_image.convert_image_dtype(image, "float32") / 255.0
    label = keras.ops.cast(label, dtype="float32")
    return image, label
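
"""
A quick one-off check (illustrative only) that the preprocessing emits the
expected shape, value range, and label shape:
"""

sample_image, sample_label = preprocess_image(x_train[0], y_train[0])
print(sample_image.shape, float(keras.ops.max(sample_image)), sample_label.shape)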


"""
## Convert the data into TensorFlow `Dataset` objects
"""

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
)

train_ds_simple = tf_data.Dataset.from_tensor_slices((x_train, y_train))

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))

train_ds_simple = (
    train_ds_simple.map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Combine two shuffled datasets from the same training data.
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

test_ds = (
    test_ds.map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
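
"""
Because the two copies of the training set are shuffled independently, zipping
them yields effectively random image pairs: each element of `train_ds` is a
pair of `(image, label)` tuples, which is exactly the structure the `cutmix`
function defined below consumes. A quick look at the element spec
(illustrative only) confirms this:
"""

print(train_ds.element_spec)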

"""
## Define the CutMix data augmentation function

The CutMix function takes two `image` and `label` pairs to perform the augmentation.
It samples `λ` (lambda) from the [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
and gets a bounding box from the `get_box` function. We then crop a patch from
the second image (`image2`) and pad it back to the full image size at the same location.
"""


def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    # TensorFlow has no direct Beta sampler, so we use the standard Gamma-ratio
    # construction: if G1 ~ Gamma(a) and G2 ~ Gamma(b), then
    # G1 / (G1 + G2) ~ Beta(a, b).
    gamma_1_sample = tf_random.gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random.gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)
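
"""
As a quick sanity check (illustrative only): with concentration 0.2 the
`Beta(0.2, 0.2)` distribution is U-shaped, so most draws land close to 0 or 1,
meaning most CutMix images are dominated by one of the two source images.
"""

print(sample_beta_distribution(5).numpy())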


def get_box(lambda_value):
    cut_rat = keras.ops.sqrt(1.0 - lambda_value)

    cut_w = IMG_SIZE * cut_rat  # rw
    cut_w = keras.ops.cast(cut_w, "int32")

    cut_h = IMG_SIZE * cut_rat  # rh
    cut_h = keras.ops.cast(cut_h, "int32")

    cut_x = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE)  # rx
    cut_x = keras.ops.cast(cut_x, "int32")
    cut_y = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE)  # ry
    cut_y = keras.ops.cast(cut_y, "int32")

    boundaryx1 = clip_by_value(cut_x[0] - cut_w // 2, 0, IMG_SIZE)
    boundaryy1 = clip_by_value(cut_y[0] - cut_h // 2, 0, IMG_SIZE)
    bbx2 = clip_by_value(cut_x[0] + cut_w // 2, 0, IMG_SIZE)
    bby2 = clip_by_value(cut_y[0] + cut_h // 2, 0, IMG_SIZE)

    target_h = bby2 - boundaryy1
    if target_h == 0:
        target_h += 1

    target_w = bbx2 - boundaryx1
    if target_w == 0:
        target_w += 1

    return boundaryx1, boundaryy1, target_h, target_w
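
"""
For intuition (illustrative only): with `lambda_value = 0.75` the cut ratio is
`sqrt(1 - 0.75) = 0.5`, so `get_box` samples a roughly 16x16 patch on the
32x32 image (smaller if the patch is clipped at the image border):
"""

print(get_box(0.75))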


def cutmix(train_ds_one, train_ds_two):
    (image1, label1), (image2, label2) = train_ds_one, train_ds_two

    alpha = [0.25]
    beta = [0.25]

    # Get a sample from the Beta distribution
    lambda_value = sample_beta_distribution(1, alpha, beta)

    # Define Lambda
    lambda_value = lambda_value[0][0]

    # Get the bounding box offsets, heights and widths
    boundaryx1, boundaryy1, target_h, target_w = get_box(lambda_value)

    # Get a patch from the second image (`image2`)
    crop2 = tf_image.crop_to_bounding_box(
        image2, boundaryy1, boundaryx1, target_h, target_w
    )
    # Pad the `image2` patch (`crop2`) with the same offset
    image2 = tf_image.pad_to_bounding_box(
        crop2, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
    )
    # Get a patch from the first image (`image1`)
    crop1 = tf_image.crop_to_bounding_box(
        image1, boundaryy1, boundaryx1, target_h, target_w
    )
    # Pad the `image1` patch (`crop1`) with the same offset
    img1 = tf_image.pad_to_bounding_box(
        crop1, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
    )

    # Modify the first image by subtracting the patch from `image1`
    # (before applying the `image2` patch)
    image1 = image1 - img1
    # Add the modified `image1` and `image2` together to get the CutMix image
    image = image1 + image2

    # Adjust Lambda in accordance with the actual pixel ratio
    lambda_value = 1 - (target_w * target_h) / (IMG_SIZE * IMG_SIZE)
    lambda_value = keras.ops.cast(lambda_value, "float32")

    # Combine the labels of both images
    label = lambda_value * label1 + (1 - lambda_value) * label2
    return image, label
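
"""
As a quick smoke test (illustrative only), we can run `cutmix` eagerly on a
single pair of samples drawn from the zipped dataset and inspect the result:
the image keeps its shape and the label is a soft mix of the two one-hot labels.
"""

sample_pair = next(iter(train_ds))
mixed_image, mixed_label = cutmix(*sample_pair)
print(mixed_image.shape, mixed_label.numpy())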


"""
**Note**: we are combining two images to create a single one.

## Visualize the new dataset after applying the CutMix augmentation
"""

# Create the new dataset using our `cutmix` utility
train_ds_cmu = (
    train_ds.shuffle(1024)
    .map(cutmix, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Let's preview 9 samples from the dataset
image_batch, label_batch = next(iter(train_ds_cmu))
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.title(class_names[np.argmax(label_batch[i])])
    plt.imshow(image_batch[i])
    plt.axis("off")

"""
## Define a ResNet-20 model
"""


def resnet_layer(
    inputs,
    num_filters=16,
    kernel_size=3,
    strides=1,
    activation="relu",
    batch_normalization=True,
    conv_first=True,
):
    conv = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        kernel_initializer="he_normal",
        kernel_regularizer=keras.regularizers.L2(1e-4),
    )
    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = layers.BatchNormalization()(x)
        if activation is not None:
            x = layers.Activation(activation)(x)
    else:
        if batch_normalization:
            x = layers.BatchNormalization()(x)
        if activation is not None:
            x = layers.Activation(activation)(x)
        x = conv(x)
    return x


def resnet_v20(input_shape, depth, num_classes=10):
    if (depth - 2) % 6 != 0:
        raise ValueError("depth should be 6n+2 (e.g. 20, 32, 44)")
    # Start model definition.
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = layers.Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # Instantiate the stack of residual units
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:  # first layer but not first stack
                strides = 2  # downsample
            y = resnet_layer(inputs=x, num_filters=num_filters, strides=strides)
            y = resnet_layer(inputs=y, num_filters=num_filters, activation=None)
            if stack > 0 and res_block == 0:  # first layer but not first stack
                # linear projection residual shortcut connection to match
                # changed dims
                x = resnet_layer(
                    inputs=x,
                    num_filters=num_filters,
                    kernel_size=1,
                    strides=strides,
                    activation=None,
                    batch_normalization=False,
                )
            x = layers.add([x, y])
            x = layers.Activation("relu")(x)
        num_filters *= 2

    # Add classifier on top.
    # v1 does not use BN after last shortcut connection-ReLU
    x = layers.AveragePooling2D(pool_size=8)(x)
    y = layers.Flatten()(x)
    outputs = layers.Dense(
        num_classes, activation="softmax", kernel_initializer="he_normal"
    )(y)

    # Instantiate model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


def training_model():
    return resnet_v20((32, 32, 3), 20)


initial_model = training_model()
initial_model.save_weights("initial_weights.weights.h5")
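
"""
As a sanity check (illustrative only): ResNet-20 for CIFAR-10 is reported at
roughly 0.27M parameters in the original ResNet paper, so the count below
should be in that ballpark.
"""

print(f"Parameters: {initial_model.count_params():,}")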

"""
## Train the model with the dataset augmented by CutMix
"""

model = training_model()
model.load_weights("initial_weights.weights.h5")

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_cmu, validation_data=test_ds, epochs=15)

test_loss, test_accuracy = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_accuracy * 100))

"""
## Train the model using the original non-augmented dataset
"""

model = training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_simple, validation_data=test_ds, epochs=15)

test_loss, test_accuracy = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_accuracy * 100))

"""
## Notes

In this example, we trained our model for 15 epochs.
The model trained with CutMix achieves better accuracy on the CIFAR-10 dataset
(77.34% in our experiment) than the model trained without the augmentation (66.90%).
You may also notice that it takes less time to train the model with the CutMix augmentation.

You can experiment further with the CutMix technique by following the
[original paper](https://arxiv.org/abs/1905.04899).
"""