"""
Title: Image classification from scratch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/27
Last modified: 2023/11/09
Description: Training an image classifier from scratch on the Kaggle Cats vs Dogs dataset.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to do image classification from scratch, starting from JPEG
image files on disk, without leveraging pre-trained weights or a pre-made Keras
Application model. We demonstrate the workflow on the Kaggle Cats vs Dogs binary
classification dataset.

We use the `image_dataset_from_directory` utility to generate the datasets, and
we use Keras image preprocessing layers for image standardization and data augmentation.
"""

"""
## Setup
"""

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

"""
## Load the data: the Cats vs Dogs dataset

### Raw data download

First, let's download the 786M ZIP archive of the raw data:
"""

"""shell
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
"""

"""shell
unzip -q kagglecatsanddogs_5340.zip
ls
"""

"""
Now we have a `PetImages` folder which contains two subfolders, `Cat` and `Dog`. Each
subfolder contains image files for its category.
"""

"""shell
ls PetImages
"""

"""
### Filter out corrupted images

When working with lots of real-world image data, corrupted images are a common
occurrence. Let's filter out badly-encoded images that do not feature the string "JFIF"
in their header.
"""

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = b"JFIF" in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")
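
"""
A stricter (but slower) alternative — sketched here for illustration only, and not
part of the original example — is to let PIL try to parse each file and discard the
ones it rejects:

```python
from PIL import Image


def is_valid_image(fpath):
    try:
        with Image.open(fpath) as img:
            img.verify()  # Raises an exception if the file is broken
        return True
    except Exception:
        return False
```

Note that `verify()` checks file integrity without fully decoding the pixel data, so
it is fast but not exhaustive.
"""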

"""
## Generate a `Dataset`
"""

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
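
"""
Each dataset yields batches of `(images, labels)` pairs. As a quick sanity check
(illustrative only; the shapes below follow from the `image_size` and `batch_size`
set above), you can pull a single batch and print its shapes:

```python
for images, labels in train_ds.take(1):
    print(images.shape)  # (128, 180, 180, 3)
    print(labels.shape)  # (128,)
```
"""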

"""
## Visualize the data

Here are the first 9 images in the training dataset.
"""

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

"""
## Using image data augmentation

When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to the
training images, such as random horizontal flipping or small random rotations. This
helps expose the model to different aspects of the training data while slowing down
overfitting.
"""

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images


"""
Let's visualize what the augmented samples look like by applying `data_augmentation`
repeatedly to the first image in the dataset:
"""

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")


"""
## Standardizing the data

Our images are already of a standard size (180x180), as they are yielded as
contiguous `float32` batches by our dataset. However, their RGB channel values are in
the `[0, 255]` range. This is not ideal for a neural network; in general, you should
seek to make your input values small. Here, we will standardize values to be in the
`[0, 1]` range by using a `Rescaling` layer at the start of our model.
"""
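
"""
As an aside, the `Rescaling` layer computes `output = input * scale + offset`, so
other target ranges are just as easy. For instance, a sketch (not used in this
example) of mapping `[0, 255]` to `[-1, 1]` instead:

```python
x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(x)
```
"""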

"""
## Two options to preprocess the data

There are two ways you could be using the `data_augmentation` preprocessor:

**Option 1: Make it part of the model**, like this:

```python
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1.0 / 255)(x)
...  # Rest of the model
```

With this option, your data augmentation will happen *on device*, synchronously
with the rest of the model execution, meaning that it will benefit from GPU
acceleration.

Note that data augmentation is inactive at test time, so the input samples will only be
augmented during `fit()`, not when calling `evaluate()` or `predict()`.

If you're training on GPU, this may be a good option.

**Option 2: Apply it to the dataset**, so as to obtain a dataset that yields batches of
augmented images, like this:

```python
augmented_train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x), y)
)
```

With this option, your data augmentation will happen **on CPU**, asynchronously, and will
be buffered before going into the model.

If you're training on CPU, this is the better option, since it makes data augmentation
asynchronous and non-blocking.

In our case, we'll go with the second option. If you're not sure
which one to pick, this second option (asynchronous preprocessing) is always a solid choice.
"""

"""
## Configure the dataset for performance

Let's apply data augmentation to our training dataset,
and let's make sure to use buffered prefetching so we can yield data from disk without
I/O becoming blocking:
"""

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)
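
"""
If the resized images fit in memory, you could also insert a `cache()` step to avoid
re-reading and re-decoding files every epoch. A sketch (not used in this example) —
note that caching must happen *before* augmentation, or the same augmented images
would be replayed every epoch:

```python
train_ds = (
    train_ds.cache()  # keep decoded, resized images in memory
    .map(lambda img, label: (data_augmentation(img), label))
    .prefetch(tf_data.AUTOTUNE)
)
```
"""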

"""
## Build a model

We'll build a small version of the Xception network. We haven't particularly tried to
optimize the architecture; if you want to do a systematic search for the best model
configuration, consider using
[KerasTuner](https://github.com/keras-team/keras-tuner).

Note that:

- We start the model with a `Rescaling` layer (the `data_augmentation` preprocessor
  has already been applied to the training dataset above).
- We include a `Dropout` layer before the final classification layer.
"""


def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)
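
"""
Note that `keras.utils.plot_model` requires the `pydot` package and the graphviz
binaries to be installed. If they are unavailable, `model.summary()` prints an
equivalent layer-by-layer overview as plain text:

```python
model.summary()
```
"""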

"""
## Train the model
"""

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

"""
We get to >90% validation accuracy after training for 25 epochs on the full dataset
(in practice, you can train for 50+ epochs before validation performance starts degrading).
"""
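
"""
If you do train for longer, it's worth guarding against that eventual degradation
automatically. A sketch (not used above) that adds an `EarlyStopping` callback, which
halts training and restores the best weights once validation accuracy stops improving:

```python
callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
    keras.callbacks.EarlyStopping(
        monitor="val_acc", patience=5, restore_best_weights=True
    ),
]
```
"""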

"""
## Run inference on new data

Note that data augmentation and dropout are inactive at inference time.
"""

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")