"""
Title: MixUp augmentation for image classification
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/03/06
Last modified: 2023/07/24
Description: Data augmentation using the mixup technique for image classification.
Accelerator: GPU
"""

"""
## Introduction
"""

"""
_mixup_ is a *domain-agnostic* data augmentation technique proposed in [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
by Zhang et al. It's implemented with the following formulas:

![](https://i.ibb.co/DRyHYww/image.png)

(Note that the lambda values lie within the [0, 1] range and are sampled from the
[Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution).)

The technique is quite systematically named: we are literally mixing up the features and
their corresponding labels. Implementation-wise it's simple. Neural networks are prone
to [memorizing corrupt labels](https://arxiv.org/abs/1611.03530). mixup relaxes this by
combining different features with one another (the same happens for the labels too) so that
a network does not get overconfident about the relationship between the features and
their labels.

mixup is especially useful when we are not sure about selecting a set of augmentation
transforms for a given dataset (medical imaging datasets, for example). mixup can be
extended to a variety of data modalities such as computer vision, natural language
processing, speech, and so on.
"""

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# TF imports related to tf.data preprocessing
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma


"""
## Prepare the dataset

In this example, we will be using the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. But this same recipe can
be used for other classification datasets as well.
"""

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = keras.ops.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = keras.ops.one_hot(y_test, 10)

"""
## Define hyperparameters
"""

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

"""
## Convert the data into TensorFlow `Dataset` objects
"""

# Put aside a few samples to create our validation set
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we will be
# combining two shuffled datasets from the same training data.
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

"""
## Define the mixup technique function

To perform the mixup routine, we create new virtual datasets using the training data from
the same dataset, and apply a lambda value within the [0, 1] range sampled from a
[Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution). For example,
`new_x = lambda * x1 + (1 - lambda) * x2` (where `x1` and `x2` are images), and the same
equation is applied to the labels as well.
"""


def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    # A Beta-distributed value can be obtained as X / (X + Y), where X and Y are
    # Gamma-distributed samples with the two concentration parameters.
    gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
    # Unpack two datasets
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = keras.ops.shape(images_one)[0]

    # Sample lambda and reshape it to do the mixup
    l = sample_beta_distribution(batch_size, alpha, alpha)
    x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
    y_l = keras.ops.reshape(l, (batch_size, 1))

    # Perform mixup on both images and labels by combining a pair of images/labels
    # (one from each dataset) into one image/label
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)


"""
**Note** that here, we are combining two images to create a single one. Theoretically,
we can combine as many as we want, but that comes at an increased computation cost. In
certain cases, it may not help improve performance either.
"""

"""
## Visualize the new augmented dataset
"""

# First create the new dataset using our `mix_up` utility
train_ds_mu = train_ds.map(
    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
    num_parallel_calls=AUTO,
)

# Let's preview 9 samples from the dataset
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    print(label.numpy().tolist())
    plt.axis("off")

"""
## Model building
"""


def get_training_model():
    model = keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model


"""
For the sake of reproducibility, we serialize the initial random weights of our shallow
network.
"""

initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")

"""
## 1. Train the model with the mixed up dataset
"""

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

"""
## 2. Train the model *without* the mixed up dataset
"""

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# Notice that we are NOT using the mixed up dataset here
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

"""
Readers are encouraged to try out mixup on different datasets from different domains and
experiment with the lambda parameter. You are strongly advised to check out the
[original paper](https://arxiv.org/abs/1710.09412) as well: the authors present several ablation studies on mixup
showing how it can improve generalization, as well as their results of combining
more than two images to create a single one.
"""

"""
## Notes

* With mixup, you can create synthetic examples, especially when you lack a large
dataset, without incurring high computational costs.
* [Label smoothing](https://www.pyimagesearch.com/2019/12/30/label-smoothing-with-keras-tensorflow-and-deep-learning/) and mixup usually do not work well together because label smoothing
already modifies the hard labels by some factor (see the short illustration after these notes).
* mixup does not work well when you are using [Supervised Contrastive
Learning](https://arxiv.org/abs/2004.11362) (SCL) since SCL expects the true labels
during its pre-training phase.
* A few other benefits of mixup include (as described in the [paper](https://arxiv.org/abs/1710.09412)) robustness to
adversarial examples and stabilized GAN (Generative Adversarial Networks) training.
* There are a number of data augmentation techniques that extend mixup such as
[CutMix](https://arxiv.org/abs/1905.04899) and [AugMix](https://arxiv.org/abs/1912.02781).
"""
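
"""
To make the label smoothing point above concrete, here is a small, purely illustrative
computation of what a smoothing factor of 0.1 does to a one-hot label over 10 classes.
The variable names below are assumptions made for this sketch; nothing here is used
elsewhere in the example.
"""

label_smoothing = 0.1
num_classes = 10
hard_label = keras.ops.one_hot(np.array([3]), num_classes)
# Standard label smoothing: take `label_smoothing` worth of probability mass away from the
# true class and spread it uniformly over all classes. The result is already a "soft"
# target, which overlaps with the soft targets that mixup produces.
soft_label = hard_label * (1.0 - label_smoothing) + label_smoothing / num_classes
print(soft_label)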