MixUp augmentation for image classification
Author: Sayak Paul
Date created: 2021/03/06
Last modified: 2023/07/24
Description: Data augmentation using the mixup technique for image classification.
Introduction
mixup is a domain-agnostic data augmentation technique proposed in mixup: Beyond Empirical Risk Minimization by Zhang et al. It's implemented with the following formulas:

new_x = lambda * x1 + (1 - lambda) * x2
new_y = lambda * y1 + (1 - lambda) * y2

(Note that the lambda values are values within the [0, 1] range and are sampled from the Beta distribution.)
The technique is quite systematically named: we are literally mixing up the features and their corresponding labels. Implementation-wise it's simple. Neural networks are prone to memorizing corrupt labels. mixup relaxes this by combining different features with one another (the same happens for the labels too) so that a network does not get overconfident about the relationship between the features and their labels.
mixup is specifically useful when we are not sure about selecting a set of augmentation transforms for a given dataset (medical imaging datasets, for example). mixup can be extended to a variety of data modalities such as computer vision, natural language processing, speech, and so on.
Setup
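A minimal sketch of the setup, assuming TensorFlow (with Keras) and Matplotlib are available; the exact imports used in the original example may differ:

```python
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import matplotlib.pyplot as plt
```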
Prepare the dataset
In this example, we will be using the FashionMNIST dataset. But this same recipe can be used for other classification datasets as well.
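As a sketch, Fashion-MNIST can be loaded from keras.datasets, scaled to the [0, 1] range, reshaped to include a channel axis, and its labels one-hot encoded. Holding out 2,000 training samples for validation is an assumption here, though it is consistent with the 907 steps per epoch shown in the training log below (with a batch size of 64):

```python
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Scale pixels to [0, 1] and add a channel dimension.
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = keras.utils.to_categorical(y_train, 10)  # one-hot encode the labels

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = keras.utils.to_categorical(y_test, 10)

# Hold out part of the training set for validation (assumed split).
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]
```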
Define hyperparameters
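As a sketch: EPOCHS = 10 matches the training log below, and BATCH_SIZE = 64 is consistent with its 907 steps per epoch given the validation hold-out assumed above; AUTO lets tf.data tune its own parallelism:

```python
AUTO = tf.data.AUTOTUNE  # let tf.data pick parallelism/prefetching levels
BATCH_SIZE = 64
EPOCHS = 10
```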
Convert the data into TensorFlow Dataset objects
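Because mixup combines pairs of examples, one way to build the pipeline (a sketch, not necessarily the exact pipeline used here) is to create two independently shuffled copies of the training set and zip them together, so that each element yields the two batches to be mixed:

```python
train_ds_one = (
    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# Each element of train_ds is a pair of batches: ((images_1, labels_1), (images_2, labels_2)).
train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
```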
Define the mixup technique function
To perform the mixup routine, we create new virtual datasets using the training data from the same dataset, and apply a lambda value within the [0, 1] range sampled from a Beta distribution, such that, for example, new_x = lambda * x1 + (1 - lambda) * x2 (where x1 and x2 are images), and the same equation is applied to the labels as well.
Note that here, we are combining two images to create a single one. Theoretically, we can combine as many as we want, but that comes at an increased computational cost. In certain cases, it may not help improve performance either.
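A sketch of the mixup routine under these assumptions: lambda is drawn from a Beta(alpha, alpha) distribution, built here from two Gamma samples since core TensorFlow does not ship a Beta sampler, and the same convex combination is applied to the images and their one-hot labels:

```python
def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    # Beta(a, b) can be sampled as G1 / (G1 + G2) with G1 ~ Gamma(a, 1), G2 ~ Gamma(b, 1).
    gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
    # Unpack the two batches of (images, labels).
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = tf.shape(images_one)[0]

    # Sample one lambda per example and reshape it for broadcasting.
    l = sample_beta_distribution(batch_size, alpha, alpha)
    x_l = tf.reshape(l, (batch_size, 1, 1, 1))
    y_l = tf.reshape(l, (batch_size, 1))

    # Mix images and labels with the same lambda.
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)


# Apply mixup to the zipped training dataset.
train_ds_mu = train_ds.map(
    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
    num_parallel_calls=AUTO,
)
```

Sampling one lambda per example in the batch, rather than a single scalar for the whole batch, keeps the mixing independent across pairs.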
Visualize the new augmented dataset
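A possible visualization using Matplotlib, showing a few blended images together with their mixed (soft) labels; the 3x3 grid is arbitrary:

```python
# Preview nine mixed-up images and their soft labels from the first batch.
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(8, 8))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze(), cmap="gray")
    print(label.numpy().tolist())  # soft labels reflect the mixing coefficients
    plt.axis("off")
plt.show()
```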
For the sake of reproducibility, we serialize the initial random weights of our shallow network.
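The network definition is not reproduced above; the small convnet below is a stand-in (its exact architecture is an assumption), with the freshly initialized weights saved so that training always starts from the same point:

```python
def get_training_model():
    # A small, shallow convnet for 28x28x1 Fashion-MNIST images (assumed architecture).
    model = keras.Sequential(
        [
            keras.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model


# Serialize the initial random weights so every run starts from the same initialization.
initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")
```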
1. Train the model with the mixed up dataset
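A training sketch consistent with the log below; the optimizer choice and the weights file name are assumptions carried over from the sketches above:

```python
model = get_training_model()
model.load_weights("initial_weights.weights.h5")  # start from the serialized initial weights
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
```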
Epoch 1/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 13s 9ms/step - accuracy: 0.5335 - loss: 1.4414 - val_accuracy: 0.7635 - val_loss: 0.6678
Epoch 2/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 12s 4ms/step - accuracy: 0.7168 - loss: 0.9688 - val_accuracy: 0.7925 - val_loss: 0.5849
Epoch 3/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - accuracy: 0.7525 - loss: 0.8940 - val_accuracy: 0.8290 - val_loss: 0.5138
Epoch 4/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.7742 - loss: 0.8431 - val_accuracy: 0.8360 - val_loss: 0.4726
Epoch 5/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.7876 - loss: 0.8095 - val_accuracy: 0.8550 - val_loss: 0.4450
Epoch 6/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8029 - loss: 0.7794 - val_accuracy: 0.8560 - val_loss: 0.4178
Epoch 7/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.8039 - loss: 0.7632 - val_accuracy: 0.8600 - val_loss: 0.4056
Epoch 8/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7465 - val_accuracy: 0.8510 - val_loss: 0.4114
Epoch 9/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7364 - val_accuracy: 0.8645 - val_loss: 0.3983
Epoch 10/10
907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8182 - loss: 0.7237 - val_accuracy: 0.8630 - val_loss: 0.3735
157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8610 - loss: 0.4030
Test accuracy: 85.82%