"""
2
Title: Estimating required sample size for model training
3
Author: [JacoVerster](https://twitter.com/JacoVerster)
4
Date created: 2021/05/20
5
Last modified: 2021/06/06
6
Description: Modeling the relationship between training set size and model accuracy.
7
Accelerator: GPU
8
"""
9
10
"""
11
# Introduction
12
13
In many real-world scenarios, the amount image data available to train a deep learning model is
14
limited. This is especially true in the medical imaging domain, where dataset creation is
15
costly. One of the first questions that usually comes up when approaching a new problem is:
16
**"how many images will we need to train a good enough machine learning model?"**
17
18
In most cases, a small set of samples is available, and we can use it to model the relationship
19
between training data size and model performance. Such a model can be used to estimate the optimal
20
number of images needed to arrive at a sample size that would achieve the required model performance.
21
22
A systematic review of
23
[Sample-Size Determination Methodologies](https://www.researchgate.net/publication/335779941_Sample-Size_Determination_Methodologies_for_Machine_Learning_in_Medical_Imaging_Research_A_Systematic_Review)
24
by Balki et al. provides examples of several sample-size determination methods. In this
25
example, a balanced subsampling scheme is used to determine the optimal sample size for
26
our model. This is done by selecting a random subsample consisting of Y number of images
27
and training the model using the subsample. The model is then evaluated on an independent
28
test set. This process is repeated N times for each subsample with replacement to allow
29
for the construction of a mean and confidence interval for the observed performance.
30
"""
31
32
"""
33
## Setup
34
"""
35
36
import os
37
38
os.environ["KERAS_BACKEND"] = "tensorflow"
39
40
import matplotlib.pyplot as plt
41
import numpy as np
42
import tensorflow as tf
43
import keras
44
from keras import layers
45
import tensorflow_datasets as tfds
46
47
# Define seed and fixed variables
48
seed = 42
49
keras.utils.set_random_seed(seed)
50
AUTO = tf.data.AUTOTUNE
51
52
"""
53
## Load TensorFlow dataset and convert to NumPy arrays
54
55
We'll be using the [TF Flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers).
56
"""
57
58
# Specify dataset parameters
59
dataset_name = "tf_flowers"
60
batch_size = 64
61
image_size = (224, 224)
62
63
# Load data from tfds and split 10% off for a test set
64
(train_data, test_data), ds_info = tfds.load(
65
dataset_name,
66
split=["train[:90%]", "train[90%:]"],
67
shuffle_files=True,
68
as_supervised=True,
69
with_info=True,
70
)
71
72
# Extract number of classes and list of class names
73
num_classes = ds_info.features["label"].num_classes
74
class_names = ds_info.features["label"].names
75
76
print(f"Number of classes: {num_classes}")
77
print(f"Class names: {class_names}")
78
79
80
# Convert datasets to NumPy arrays
def dataset_to_array(dataset, image_size, num_classes):
    images, labels = [], []
    for img, lab in dataset.as_numpy_iterator():
        images.append(tf.image.resize(img, image_size).numpy())
        labels.append(tf.one_hot(lab, num_classes))
    return np.array(images), np.array(labels)


img_train, label_train = dataset_to_array(train_data, image_size, num_classes)
img_test, label_test = dataset_to_array(test_data, image_size, num_classes)

num_train_samples = len(img_train)
print(f"Number of training samples: {num_train_samples}")

"""
96
## Plot a few examples from the test set
97
"""
98
99
plt.figure(figsize=(16, 12))
100
for n in range(30):
101
ax = plt.subplot(5, 6, n + 1)
102
plt.imshow(img_test[n].astype("uint8"))
103
plt.title(np.array(class_names)[label_test[n] == True][0])
104
plt.axis("off")
105
106
"""
107
## Augmentation
108
109
Define image augmentation using keras preprocessing layers and apply them to the training set.
110
"""
111
112
# Define image augmentation model
113
image_augmentation = keras.Sequential(
114
[
115
layers.RandomFlip(mode="horizontal"),
116
layers.RandomRotation(factor=0.1),
117
layers.RandomZoom(height_factor=(-0.1, -0)),
118
layers.RandomContrast(factor=0.1),
119
],
120
)
121
122
# Apply the augmentations to the training images and plot a few examples
123
img_train = image_augmentation(img_train).numpy()
124
125
plt.figure(figsize=(16, 12))
126
for n in range(30):
127
ax = plt.subplot(5, 6, n + 1)
128
plt.imshow(img_train[n].astype("uint8"))
129
plt.title(np.array(class_names)[label_train[n] == True][0])
130
plt.axis("off")
131
132
"""
133
## Define model building & training functions
134
135
We create a few convenience functions to build a transfer-learning model, compile and
136
train it and unfreeze layers for fine-tuning.
137
"""
138
139
140
def build_model(num_classes, img_size=image_size[0], top_dropout=0.3):
    """Creates a classifier based on pre-trained MobileNetV2.

    Arguments:
        num_classes: Int, number of classes to use in the softmax layer.
        img_size: Int, square size of input images (default is 224).
        top_dropout: Float, value for dropout layer (default is 0.3).

    Returns:
        Uncompiled Keras model.
    """

    # Create input and pre-processing layers for MobileNetV2
    inputs = layers.Input(shape=(img_size, img_size, 3))
    x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
    model = keras.applications.MobileNetV2(
        include_top=False, weights="imagenet", input_tensor=x
    )

    # Freeze the pretrained weights
    model.trainable = False

    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.Dropout(top_dropout)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = keras.Model(inputs, outputs)

    print("Trainable weights:", len(model.trainable_weights))
    print("Non-trainable weights:", len(model.non_trainable_weights))
    return model


def compile_and_train(
    model,
    training_data,
    training_labels,
    metrics=[keras.metrics.AUC(name="auc"), "acc"],
    optimizer=keras.optimizers.Adam(),
    patience=5,
    epochs=5,
):
    """Compiles and trains the model.

    Arguments:
        model: Uncompiled Keras model.
        training_data: NumPy Array, training data.
        training_labels: NumPy Array, training labels.
        metrics: Keras/TF metrics, requires at least 'auc' metric (default is
            `[keras.metrics.AUC(name='auc'), 'acc']`).
        optimizer: Keras/TF optimizer (default is `keras.optimizers.Adam()`).
        patience: Int, number of epochs of patience for EarlyStopping (default is 5).
        epochs: Int, number of epochs to train (default is 5).

    Returns:
        Training history for the trained Keras model.
    """

    stopper = keras.callbacks.EarlyStopping(
        monitor="val_auc",
        mode="max",
        min_delta=0,
        patience=patience,
        verbose=1,
        restore_best_weights=True,
    )

    model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=metrics)

    history = model.fit(
        x=training_data,
        y=training_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
        callbacks=[stopper],
    )
    return history


def unfreeze(model, block_name, verbose=0):
    """Unfreezes Keras model layers.

    Arguments:
        model: Keras model.
        block_name: Str, layer name, for example block_name = 'block4'.
            Checks if the supplied string is in the layer name.
        verbose: Int, 0 means silent, 1 prints out layers' trainability status.

    Returns:
        Keras model with all layers after (and including) the specified
        block_name set to trainable, excluding BatchNormalization layers.
    """

    # Unfreeze from block_name onwards
    set_trainable = False

    for layer in model.layers:
        if block_name in layer.name:
            set_trainable = True
        if set_trainable and not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True
            if verbose == 1:
                print(layer.name, "trainable")
        else:
            if verbose == 1:
                print(layer.name, "NOT trainable")
    print("Trainable weights:", len(model.trainable_weights))
    print("Non-trainable weights:", len(model.non_trainable_weights))
    return model


"""
253
## Define iterative training function
254
255
To train a model over several subsample sets we need to create an iterative training function.
256
"""
257
258
259
def train_model(training_data, training_labels):
    """Trains the model as follows:

    - Trains only the top layers for 10 epochs.
    - Unfreezes deeper layers.
    - Trains for 20 more epochs.

    Arguments:
        training_data: NumPy Array, training data.
        training_labels: NumPy Array, training labels.

    Returns:
        Model accuracy.
    """

    model = build_model(num_classes)

    # Compile and train top layers
    history = compile_and_train(
        model,
        training_data,
        training_labels,
        metrics=[keras.metrics.AUC(name="auc"), "acc"],
        optimizer=keras.optimizers.Adam(),
        patience=3,
        epochs=10,
    )

    # Unfreeze model from block 10 onwards
    model = unfreeze(model, "block_10")

    # Compile and train for 20 more epochs with a lower learning rate
    fine_tune_epochs = 20
    total_epochs = history.epoch[-1] + fine_tune_epochs

    history_fine = compile_and_train(
        model,
        training_data,
        training_labels,
        metrics=[keras.metrics.AUC(name="auc"), "acc"],
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        patience=5,
        epochs=total_epochs,
    )

    # Calculate model accuracy on the test set
    _, _, acc = model.evaluate(img_test, label_test)
    return np.round(acc, 4)


"""
310
## Train models iteratively
311
312
Now that we have model building functions and supporting iterative functions we can train
313
the model over several subsample splits.
314
315
- We select the subsample splits as 5%, 10%, 25% and 50% of the downloaded dataset. We
316
pretend that only 50% of the actual data is available at present.
317
- We train the model 5 times from scratch at each split and record the accuracy values.
318
319
Note that this trains 20 models and will take some time. Make sure you have a GPU runtime
320
active.
321
322
To keep this example lightweight, sample data from a previous training run is provided.
323
"""
324
325
326
def train_iteratively(sample_splits=[0.05, 0.1, 0.25, 0.5], iter_per_split=5):
    """Trains a model iteratively over several sample splits.

    Arguments:
        sample_splits: List/NumPy array, contains fractions of the training set
            to train over.
        iter_per_split: Int, number of times to train a model per sample split.

    Returns:
        Training accuracy for all splits and iterations and the number of samples
        used for training at each split.
    """
    # Train all the sample models and calculate accuracy
    train_acc = []
    sample_sizes = []

    for fraction in sample_splits:
        print(f"Fraction split: {fraction}")
        # Repeat training `iter_per_split` times for each sample size
        sample_accuracy = []
        num_samples = int(num_train_samples * fraction)
        for i in range(iter_per_split):
            print(f"Run {i+1} out of {iter_per_split}:")
            # Create fractional subsets
            rand_idx = np.random.randint(num_train_samples, size=num_samples)
            train_img_subset = img_train[rand_idx, :]
            train_label_subset = label_train[rand_idx, :]
            # Train model and calculate accuracy
            accuracy = train_model(train_img_subset, train_label_subset)
            print(f"Accuracy: {accuracy}")
            sample_accuracy.append(accuracy)
        train_acc.append(sample_accuracy)
        sample_sizes.append(num_samples)
    return train_acc, sample_sizes


# Running the above function produces the following outputs
train_acc = [
    [0.8202, 0.7466, 0.8011, 0.8447, 0.8229],
    [0.861, 0.8774, 0.8501, 0.8937, 0.891],
    [0.891, 0.9237, 0.8856, 0.9101, 0.891],
    [0.8937, 0.9373, 0.9128, 0.8719, 0.9128],
]

sample_sizes = [165, 330, 825, 1651]

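"""
The introduction describes building a mean and confidence interval from the repeated
runs at each sample size. As a quick illustration of that step (not part of the original
training run), the snippet below computes the per-split mean accuracy and an approximate
95% confidence interval, assuming a normal approximation (1.96 standard errors over the
5 runs).
"""

# Illustrative only: per-split mean accuracy and approximate 95% confidence interval,
# using a normal approximation over the runs recorded in `train_acc`.
for size, accs in zip(sample_sizes, train_acc):
    accs = np.array(accs)
    mean = accs.mean()
    # Standard error of the mean; ddof=1 uses the sample standard deviation
    sem = accs.std(ddof=1) / np.sqrt(len(accs))
    print(f"{size} samples: mean acc = {mean:.4f} +/- {1.96 * sem:.4f} (95% CI)")
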
"""
373
## Learning curve
374
375
We now plot the learning curve by fitting an exponential curve through the mean accuracy
376
points. We use TF to fit an exponential function through the data.
377
378
We then extrapolate the learning curve to the predict the accuracy of a model trained on
379
the whole training set.
380
"""
381
382
383
def fit_and_predict(train_acc, sample_sizes, pred_sample_size):
    """Fits a learning curve to model training accuracy results.

    Arguments:
        train_acc: List/Numpy Array, training accuracy for all model
            training splits and iterations.
        sample_sizes: List/Numpy array, number of samples used for training at
            each split.
        pred_sample_size: Int, sample size to predict model accuracy based on
            fitted learning curve.
    """
    x = sample_sizes
    mean_acc = tf.convert_to_tensor([np.mean(i) for i in train_acc])
    error = [np.std(i) for i in train_acc]

    # Define mean squared error cost and exponential curve fit functions
    mse = keras.losses.MeanSquaredError()

    def exp_func(x, a, b):
        return a * x**b

    # Define variables, learning rate and number of epochs for fitting with TF
    a = tf.Variable(0.0)
    b = tf.Variable(0.0)
    learning_rate = 0.01
    training_epochs = 5000

    # Fit the exponential function to the data
    for epoch in range(training_epochs):
        with tf.GradientTape() as tape:
            y_pred = exp_func(x, a, b)
            cost_function = mse(y_pred, mean_acc)
        # Get gradients and compute adjusted weights
        gradients = tape.gradient(cost_function, [a, b])
        a.assign_sub(gradients[0] * learning_rate)
        b.assign_sub(gradients[1] * learning_rate)
    print(f"Curve fit weights: a = {a.numpy()} and b = {b.numpy()}.")

    # We can now estimate the accuracy for pred_sample_size
    max_acc = exp_func(pred_sample_size, a, b).numpy()

    # Print predicted x value and append to plot values
    print(f"A model accuracy of {max_acc} is predicted for {pred_sample_size} samples.")
    x_cont = np.linspace(x[0], pred_sample_size, 100)

    # Build the plot
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.errorbar(x, mean_acc, yerr=error, fmt="o", label="Mean acc & std dev.")
    ax.plot(x_cont, exp_func(x_cont, a, b), "r-", label="Fitted exponential curve.")
    ax.set_ylabel("Model classification accuracy.", fontsize=12)
    ax.set_xlabel("Training sample size.", fontsize=12)
    ax.set_xticks(np.append(x, pred_sample_size))
    ax.set_yticks(np.append(mean_acc, max_acc))
    ax.set_xticklabels(list(np.append(x, pred_sample_size)), rotation=90, fontsize=10)
    ax.yaxis.set_tick_params(labelsize=10)
    ax.set_title("Learning curve: model accuracy vs sample size.", fontsize=14)
    ax.legend(loc=(0.75, 0.75), fontsize=10)
    ax.xaxis.grid(True)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.show()

    # The mean absolute error (MAE) is calculated for the curve fit to see how well
    # it fits the data. The lower the error the better the fit.
    mae = keras.losses.MeanAbsoluteError()
    print(f"The mae for the curve fit is {mae(mean_acc, exp_func(x, a, b)).numpy()}.")


# We use the whole training set to predict the model accuracy
fit_and_predict(train_acc, sample_sizes, pred_sample_size=num_train_samples)

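"""
As a quick cross-check of the gradient-descent fit above (not part of the original
example), the same power law `a * x**b` can be estimated in closed form with a linear
least-squares fit in log-log space, since `log(acc) = log(a) + b * log(x)`. The result
may differ slightly from the TF fit because the error is minimized in log space. The
fitted curve can also be inverted, `n = (target_acc / a) ** (1 / b)`, to estimate how
many images would be needed to reach a target accuracy; the 0.95 target below is purely
illustrative.
"""

# Closed-form power-law fit in log-log space (illustrative cross-check only)
mean_acc_np = np.array([np.mean(i) for i in train_acc])
b_ls, log_a_ls = np.polyfit(np.log(sample_sizes), np.log(mean_acc_np), 1)
a_ls = np.exp(log_a_ls)
print(f"Log-log least-squares fit: a = {a_ls:.4f}, b = {b_ls:.4f}")

# Invert the fitted curve to estimate the sample size needed for a target accuracy
target_acc = 0.95
required_samples = int(np.ceil((target_acc / a_ls) ** (1 / b_ls)))
print(f"Estimated samples needed for {target_acc:.0%} accuracy: {required_samples}")
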
"""
455
From the extrapolated curve we can see that 3303 images will yield an estimated
456
accuracy of about 95%.
457
458
Now, let's use all the data (3303 images) and train the model to see if our prediction
459
was accurate!
460
"""
461
462
# Now train the model with the full dataset to get the actual accuracy
accuracy = train_model(img_train, label_train)
print(f"A model accuracy of {accuracy} is reached on {num_train_samples} images!")

"""
467
## Conclusion
468
469
We see that a model accuracy of about 94-96%* is reached using 3303 images. This is quite
470
close to our estimate!
471
472
Even though we used only 50% of the dataset (1651 images) we were able to model the training
473
behaviour of our model and predict the model accuracy for a given amount of images. This same
474
methodology can be used to predict the amount of images needed to reach a desired accuracy.
475
This is very useful when a smaller set of data is available, and it has been shown that
476
convergence on a deep learning model is possible, but more images are needed. The image count
477
prediction can be used to plan and budget for further image collection initiatives.
478
"""
479
480