Path: blob/master/examples/keras_recipes/sample_size_estimate.py
"""1Title: Estimating required sample size for model training2Author: [JacoVerster](https://twitter.com/JacoVerster)3Date created: 2021/05/204Last modified: 2021/06/065Description: Modeling the relationship between training set size and model accuracy.6Accelerator: GPU7"""89"""10# Introduction1112In many real-world scenarios, the amount image data available to train a deep learning model is13limited. This is especially true in the medical imaging domain, where dataset creation is14costly. One of the first questions that usually comes up when approaching a new problem is:15**"how many images will we need to train a good enough machine learning model?"**1617In most cases, a small set of samples is available, and we can use it to model the relationship18between training data size and model performance. Such a model can be used to estimate the optimal19number of images needed to arrive at a sample size that would achieve the required model performance.2021A systematic review of22[Sample-Size Determination Methodologies](https://www.researchgate.net/publication/335779941_Sample-Size_Determination_Methodologies_for_Machine_Learning_in_Medical_Imaging_Research_A_Systematic_Review)23by Balki et al. provides examples of several sample-size determination methods. In this24example, a balanced subsampling scheme is used to determine the optimal sample size for25our model. This is done by selecting a random subsample consisting of Y number of images26and training the model using the subsample. The model is then evaluated on an independent27test set. This process is repeated N times for each subsample with replacement to allow28for the construction of a mean and confidence interval for the observed performance.29"""3031"""32## Setup33"""3435import os3637os.environ["KERAS_BACKEND"] = "tensorflow"3839import matplotlib.pyplot as plt40import numpy as np41import tensorflow as tf42import keras43from keras import layers44import tensorflow_datasets as tfds4546# Define seed and fixed variables47seed = 4248keras.utils.set_random_seed(seed)49AUTO = tf.data.AUTOTUNE5051"""52## Load TensorFlow dataset and convert to NumPy arrays5354We'll be using the [TF Flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers).55"""5657# Specify dataset parameters58dataset_name = "tf_flowers"59batch_size = 6460image_size = (224, 224)6162# Load data from tfds and split 10% off for a test set63(train_data, test_data), ds_info = tfds.load(64dataset_name,65split=["train[:90%]", "train[90%:]"],66shuffle_files=True,67as_supervised=True,68with_info=True,69)7071# Extract number of classes and list of class names72num_classes = ds_info.features["label"].num_classes73class_names = ds_info.features["label"].names7475print(f"Number of classes: {num_classes}")76print(f"Class names: {class_names}")777879# Convert datasets to NumPy arrays80def dataset_to_array(dataset, image_size, num_classes):81images, labels = [], []82for img, lab in dataset.as_numpy_iterator():83images.append(tf.image.resize(img, image_size).numpy())84labels.append(tf.one_hot(lab, num_classes))85return np.array(images), np.array(labels)868788img_train, label_train = dataset_to_array(train_data, image_size, num_classes)89img_test, label_test = dataset_to_array(test_data, image_size, num_classes)9091num_train_samples = len(img_train)92print(f"Number of training samples: {num_train_samples}")9394"""95## Plot a few examples from the test set96"""9798plt.figure(figsize=(16, 12))99for n in range(30):100ax = plt.subplot(5, 6, n + 

"""
## Plot a few examples from the test set
"""

plt.figure(figsize=(16, 12))
for n in range(30):
    ax = plt.subplot(5, 6, n + 1)
    plt.imshow(img_test[n].astype("uint8"))
    plt.title(np.array(class_names)[label_test[n] == True][0])
    plt.axis("off")

"""
## Augmentation

Define image augmentation using keras preprocessing layers and apply them to the training set.
"""

# Define image augmentation model
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=(-0.1, -0)),
        layers.RandomContrast(factor=0.1),
    ],
)

# Apply the augmentations to the training images and plot a few examples
img_train = image_augmentation(img_train).numpy()

plt.figure(figsize=(16, 12))
for n in range(30):
    ax = plt.subplot(5, 6, n + 1)
    plt.imshow(img_train[n].astype("uint8"))
    plt.title(np.array(class_names)[label_train[n] == True][0])
    plt.axis("off")

"""
## Define model building & training functions

We create a few convenience functions to build a transfer-learning model, compile and
train it, and unfreeze layers for fine-tuning.
"""


def build_model(num_classes, img_size=image_size[0], top_dropout=0.3):
    """Creates a classifier based on pre-trained MobileNetV2.

    Arguments:
        num_classes: Int, number of classes to use in the softmax layer.
        img_size: Int, square size of input images (default is 224).
        top_dropout: Float, value for dropout layer (default is 0.3).

    Returns:
        Uncompiled Keras model.
    """

    # Create input and pre-processing layers for MobileNetV2
    inputs = layers.Input(shape=(img_size, img_size, 3))
    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
    model = keras.applications.MobileNetV2(
        include_top=False, weights="imagenet", input_tensor=x
    )

    # Freeze the pretrained weights
    model.trainable = False

    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.Dropout(top_dropout)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs, outputs)

    print("Trainable weights:", len(model.trainable_weights))
    print("Non-trainable weights:", len(model.non_trainable_weights))
    return model


def compile_and_train(
    model,
    training_data,
    training_labels,
    metrics=[keras.metrics.AUC(name="auc"), "acc"],
    optimizer=keras.optimizers.Adam(),
    patience=5,
    epochs=5,
):
    """Compiles and trains the model.

    Arguments:
        model: Uncompiled Keras model.
        training_data: NumPy Array, training data.
        training_labels: NumPy Array, training labels.
        metrics: Keras/TF metrics, requires at least 'auc' metric (default is
            `[keras.metrics.AUC(name='auc'), 'acc']`).
        optimizer: Keras/TF optimizer (default is `keras.optimizers.Adam()`).
        patience: Int, number of epochs for EarlyStopping patience (default is 5).
        epochs: Int, number of epochs to train (default is 5).

    Returns:
        Training history for trained Keras model.
    """

    stopper = keras.callbacks.EarlyStopping(
        monitor="val_auc",
        mode="max",
        min_delta=0,
        patience=patience,
        verbose=1,
        restore_best_weights=True,
    )

    model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=metrics)

    history = model.fit(
        x=training_data,
        y=training_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
        callbacks=[stopper],
    )
    return history


def unfreeze(model, block_name, verbose=0):
    """Unfreezes Keras model layers.

    Arguments:
        model: Keras model.
        block_name: Str, layer name for example block_name = 'block4'.
            Checks if supplied string is in the layer name.
        verbose: Int, 0 means silent, 1 prints out layers trainability status.

    Returns:
        Keras model with all layers after (and including) the specified
        block_name set to trainable, excluding BatchNormalization layers.
    """

    # Unfreeze from block_name onwards
    set_trainable = False

    for layer in model.layers:
        if block_name in layer.name:
            set_trainable = True
        if set_trainable and not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True
            if verbose == 1:
                print(layer.name, "trainable")
        else:
            if verbose == 1:
                print(layer.name, "NOT trainable")
    print("Trainable weights:", len(model.trainable_weights))
    print("Non-trainable weights:", len(model.non_trainable_weights))
    return model
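

"""
The `block_name` argument matches a substring of the layer names inside MobileNetV2. If you
are unsure what those names look like, the optional snippet below lists a few of them;
`weights=None` is used because only the layer names matter here.
"""

# Optional: inspect MobileNetV2 layer names to pick a sensible `block_name`
tmp_model = keras.applications.MobileNetV2(
    include_top=False, weights=None, input_shape=image_size + (3,)
)
print([layer.name for layer in tmp_model.layers if "block_10" in layer.name][:5])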

"""
## Define iterative training function

To train a model over several subsample sets, we need to create an iterative training function.
"""


def train_model(training_data, training_labels):
    """Trains the model as follows:

    - Trains only the top layers for 10 epochs.
    - Unfreezes deeper layers.
    - Trains for 20 more epochs.

    Arguments:
        training_data: NumPy Array, training data.
        training_labels: NumPy Array, training labels.

    Returns:
        Model accuracy.
    """

    model = build_model(num_classes)

    # Compile and train top layers
    history = compile_and_train(
        model,
        training_data,
        training_labels,
        metrics=[keras.metrics.AUC(name="auc"), "acc"],
        optimizer=keras.optimizers.Adam(),
        patience=3,
        epochs=10,
    )

    # Unfreeze model from block 10 onwards
    model = unfreeze(model, "block_10")

    # Compile and train for 20 epochs with a lower learning rate
    fine_tune_epochs = 20
    total_epochs = history.epoch[-1] + fine_tune_epochs

    history_fine = compile_and_train(
        model,
        training_data,
        training_labels,
        metrics=[keras.metrics.AUC(name="auc"), "acc"],
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        patience=5,
        epochs=total_epochs,
    )

    # Calculate model accuracy on the test set
    _, _, acc = model.evaluate(img_test, label_test)
    return np.round(acc, 4)


"""
## Train models iteratively

Now that we have model building functions and supporting iterative functions we can train
the model over several subsample splits.

- We select the subsample splits as 5%, 10%, 25% and 50% of the downloaded dataset. We
pretend that only 50% of the actual data is available at present.
- We train the model 5 times from scratch at each split and record the accuracy values.

Note that this trains 20 models and will take some time. Make sure you have a GPU runtime
active.

To keep this example lightweight, sample data from a previous training run is provided.
"""


def train_iteratively(sample_splits=[0.05, 0.1, 0.25, 0.5], iter_per_split=5):
    """Trains a model iteratively over several sample splits.

    Arguments:
        sample_splits: List/NumPy array, contains fractions of the training set
            to train over.
        iter_per_split: Int, number of times to train a model per sample split.

    Returns:
        Training accuracy for all splits and iterations and the number of samples
        used for training at each split.
    """
    # Train all the sample models and calculate accuracy
    train_acc = []
    sample_sizes = []

    for fraction in sample_splits:
        print(f"Fraction split: {fraction}")
        # Repeat training `iter_per_split` times for each sample size
        sample_accuracy = []
        num_samples = int(num_train_samples * fraction)
        for i in range(iter_per_split):
            print(f"Run {i+1} out of {iter_per_split}:")
            # Create fractional subsets
            rand_idx = np.random.randint(num_train_samples, size=num_samples)
            train_img_subset = img_train[rand_idx, :]
            train_label_subset = label_train[rand_idx, :]
            # Train model and calculate accuracy
            accuracy = train_model(train_img_subset, train_label_subset)
            print(f"Accuracy: {accuracy}")
            sample_accuracy.append(accuracy)
        train_acc.append(sample_accuracy)
        sample_sizes.append(num_samples)
    return train_acc, sample_sizes


# Running the above function produces the following outputs
train_acc = [
    [0.8202, 0.7466, 0.8011, 0.8447, 0.8229],
    [0.861, 0.8774, 0.8501, 0.8937, 0.891],
    [0.891, 0.9237, 0.8856, 0.9101, 0.891],
    [0.8937, 0.9373, 0.9128, 0.8719, 0.9128],
]

sample_sizes = [165, 330, 825, 1651]
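
"""
The learning-curve plot below summarizes each split with its mean accuracy and standard
deviation. If you also want the confidence interval mentioned in the introduction, a minimal
sketch using a normal approximation (mean +/- 1.96 standard errors) looks like this.
"""

# Summarize the repeated runs per sample size: mean accuracy and an
# approximate 95% confidence interval (normal approximation).
for n_samples, accuracies in zip(sample_sizes, train_acc):
    accuracies = np.array(accuracies)
    mean = accuracies.mean()
    sem = accuracies.std(ddof=1) / np.sqrt(len(accuracies))
    print(
        f"{n_samples} samples: mean accuracy = {mean:.4f}, "
        f"95% CI = [{mean - 1.96 * sem:.4f}, {mean + 1.96 * sem:.4f}]"
    )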

"""
## Learning curve

We now plot the learning curve by fitting an exponential curve through the mean accuracy
points. We use TF to fit an exponential function through the data.

We then extrapolate the learning curve to predict the accuracy of a model trained on
the whole training set.
"""


def fit_and_predict(train_acc, sample_sizes, pred_sample_size):
    """Fits a learning curve to model training accuracy results.

    Arguments:
        train_acc: List/Numpy Array, training accuracy for all model
            training splits and iterations.
        sample_sizes: List/Numpy array, number of samples used for training at
            each split.
        pred_sample_size: Int, sample size to predict model accuracy based on
            fitted learning curve.
    """
    x = sample_sizes
    mean_acc = tf.convert_to_tensor([np.mean(i) for i in train_acc])
    error = [np.std(i) for i in train_acc]

    # Define mean squared error cost and exponential curve fit functions
    mse = keras.losses.MeanSquaredError()

    def exp_func(x, a, b):
        return a * x**b

    # Define variables, learning rate and number of epochs for fitting with TF
    a = tf.Variable(0.0)
    b = tf.Variable(0.0)
    learning_rate = 0.01
    training_epochs = 5000

    # Fit the exponential function to the data
    for epoch in range(training_epochs):
        with tf.GradientTape() as tape:
            y_pred = exp_func(x, a, b)
            cost_function = mse(y_pred, mean_acc)
        # Get gradients and compute adjusted weights
        gradients = tape.gradient(cost_function, [a, b])
        a.assign_sub(gradients[0] * learning_rate)
        b.assign_sub(gradients[1] * learning_rate)
    print(f"Curve fit weights: a = {a.numpy()} and b = {b.numpy()}.")

    # We can now estimate the accuracy for pred_sample_size
    max_acc = exp_func(pred_sample_size, a, b).numpy()

    # Print predicted x value and append to plot values
    print(f"A model accuracy of {max_acc} is predicted for {pred_sample_size} samples.")
    x_cont = np.linspace(x[0], pred_sample_size, 100)

    # Build the plot
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.errorbar(x, mean_acc, yerr=error, fmt="o", label="Mean acc & std dev.")
    ax.plot(x_cont, exp_func(x_cont, a, b), "r-", label="Fitted exponential curve.")
    ax.set_ylabel("Model classification accuracy.", fontsize=12)
    ax.set_xlabel("Training sample size.", fontsize=12)
    ax.set_xticks(np.append(x, pred_sample_size))
    ax.set_yticks(np.append(mean_acc, max_acc))
    ax.set_xticklabels(list(np.append(x, pred_sample_size)), rotation=90, fontsize=10)
    ax.yaxis.set_tick_params(labelsize=10)
    ax.set_title("Learning curve: model accuracy vs sample size.", fontsize=14)
    ax.legend(loc=(0.75, 0.75), fontsize=10)
    ax.xaxis.grid(True)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.show()

    # The mean absolute error (MAE) is calculated for the curve fit to see how well
    # it fits the data. The lower the error, the better the fit.
    mae = keras.losses.MeanAbsoluteError()
    print(f"The MAE for the curve fit is {mae(mean_acc, exp_func(x, a, b)).numpy()}.")
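

"""
As an optional cross-check, the same `a * x ** b` curve used in `exp_func` can be fitted with
`scipy.optimize.curve_fit`. This assumes SciPy is installed, which the rest of this example
does not require, and it should produce weights close to the gradient-descent fit above.
"""

# Optional cross-check of the curve fit using SciPy (extra dependency)
try:
    from scipy.optimize import curve_fit

    mean_accuracies = np.array([np.mean(i) for i in train_acc])
    popt, _ = curve_fit(
        lambda n, a, b: a * np.power(n, b),  # same curve form as exp_func
        np.array(sample_sizes, dtype=np.float64),
        mean_accuracies,
        p0=[0.5, 0.05],  # rough starting guess for (a, b)
        maxfev=10000,
    )
    print(f"SciPy curve fit weights: a = {popt[0]:.4f} and b = {popt[1]:.4f}.")
except ImportError:
    print("SciPy not installed; skipping the optional cross-check.")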


# We use the whole training set to predict the model accuracy
fit_and_predict(train_acc, sample_sizes, pred_sample_size=num_train_samples)

"""
From the extrapolated curve we can see that 3303 images will yield an estimated
accuracy of about 95%.

Now, let's use all the data (3303 images) and train the model to see if our prediction
was accurate!
"""

# Now train the model with the full dataset to get the actual accuracy
accuracy = train_model(img_train, label_train)
print(f"A model accuracy of {accuracy} is reached on {num_train_samples} images!")

"""
## Conclusion

We see that a model accuracy of about 94-96% is reached using 3303 images. This is quite
close to our estimate!

Even though we used only 50% of the dataset (1651 images), we were able to model the training
behaviour of our model and predict the model accuracy for a given number of images. This same
methodology can be used to predict the number of images needed to reach a desired accuracy.
This is very useful when a smaller set of data is available, and it has been shown that the
deep learning model can converge, but more images are needed. The image count prediction can
be used to plan and budget for further image collection initiatives.
"""
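
"""
As mentioned above, the fitted curve can also be inverted to estimate how many images would be
needed to reach a desired accuracy. The helper below is only a sketch: it assumes you have the
fitted weights at hand (for instance by having `fit_and_predict` return `a` and `b`), and
extrapolating far beyond the fitted sample sizes should be treated with caution.
"""


def estimate_required_samples(target_accuracy, a, b):
    """Inverts the fitted curve (a * n**b = target_accuracy) to estimate the
    number of training samples `n` needed to reach `target_accuracy`."""
    return int(np.ceil((target_accuracy / a) ** (1 / b)))


# Illustrative usage: plug in the curve fit weights printed by `fit_and_predict`,
# e.g. estimate_required_samples(0.95, a=fitted_a, b=fitted_b)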