"""
Title: Image classification from scratch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/27
Last modified: 2023/11/09
Description: Training an image classifier from scratch on the Kaggle Cats vs Dogs dataset.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to do image classification from scratch, starting from JPEG
image files on disk, without leveraging pre-trained weights or a pre-made Keras
Application model. We demonstrate the workflow on the Kaggle Cats vs Dogs binary
classification dataset.

We use the `image_dataset_from_directory` utility to generate the datasets, and
we use Keras image preprocessing layers for image standardization and data augmentation.
"""

"""
## Setup
"""

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

"""
## Load the data: the Cats vs Dogs dataset

### Raw data download

First, let's download the 786M ZIP archive of the raw data:
"""

"""shell
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
"""

"""shell
unzip -q kagglecatsanddogs_5340.zip
ls
"""

"""
Now we have a `PetImages` folder which contains two subfolders, `Cat` and `Dog`. Each
subfolder contains image files for each category.
"""

"""shell
ls PetImages
"""

"""
### Filter out corrupted images

When working with lots of real-world image data, corrupted images are a common
occurrence.
Let's filter out badly-encoded images that do not feature the string "JFIF"
in their header.
"""

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # The context manager closes the file even if reading the header fails.
        with open(fpath, "rb") as fobj:
            is_jfif = b"JFIF" in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")

"""
## Generate a `Dataset`
"""

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

"""
## Visualize the data

Here are the first 9 images in the training dataset.
"""


plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

"""
## Using image data augmentation

When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to the
training images, such as random horizontal flipping or small random rotations.
This helps expose the model to different aspects of the training data while slowing down
overfitting.
"""

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images


"""
Let's visualize what the augmented samples look like, by applying `data_augmentation`
repeatedly to the first few images in the dataset:
"""

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")


"""
## Standardizing the data

Our images are already in a standard size (180x180), as they are being yielded as
contiguous `float32` batches by our dataset. However, their RGB channel values are in
the `[0, 255]` range. This is not ideal for a neural network;
in general you should seek to make your input values small. Here, we will
standardize values to be in the `[0, 1]` range by using a `Rescaling` layer at the
start of our model.
"""

"""
## Two options to preprocess the data

There are two ways you could be using the `data_augmentation` preprocessor:

**Option 1: Make it part of the model**, like this:

```python
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
...  # Rest of the model
```

With this option, your data augmentation will happen *on device*, synchronously
with the rest of the model execution, meaning that it will benefit from GPU
acceleration.

Note that data augmentation is inactive at test time, so the input samples will only be
augmented during `fit()`, not when calling `evaluate()` or `predict()`.

If you're training on GPU, this may be a good option.

**Option 2: apply it to the dataset**, so as to obtain a dataset that yields batches of
augmented images, like this:

```python
augmented_train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x), y))
```

With this option, your data augmentation will happen **on CPU**, asynchronously, and will
be buffered before going into the model.

If you're training on CPU, this is the better option, since it makes data augmentation
asynchronous and non-blocking.

In our case, we'll go with the second option. If you're not sure
which one to pick, this second option (asynchronous preprocessing) is always a solid choice.
"""

"""
## Configure the dataset for performance

Let's apply data augmentation to our training dataset,
and let's make sure to use buffered prefetching so we can yield data from disk without
I/O becoming blocking:
"""

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

"""
## Build a model

We'll build a small version of the Xception network.
We haven't particularly tried to
optimize the architecture; if you want to do a systematic search for the best model
configuration, consider using
[KerasTuner](https://github.com/keras-team/keras-tuner).

Note that:

- We start the model with a `Rescaling` layer; data augmentation has already been
applied to the training dataset above.
- We include a `Dropout` layer before the final classification layer.
"""


def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

"""
## Train the model
"""

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

"""
We get to >90% validation accuracy after training for 25 epochs on the full dataset
(in practice, you can train for 50+ epochs before validation performance starts degrading).
"""

"""
## Run inference on new data

Note that data augmentation and dropout are inactive at inference time.
"""

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")
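
"""
The model returns a single logit, so the score conversion used above can also be
sketched without Keras. The snippet below is a minimal illustration, not part of the
original example: `sigmoid` and `cat_dog_percentages` are hypothetical helper names, and
it assumes label 1 corresponds to `Dog`, since `image_dataset_from_directory` assigns
labels in alphanumeric order of the subfolder names.

```python
import math


def sigmoid(logit):
    # Numerically stable logistic function: avoids overflow in exp()
    # for large negative logits.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def cat_dog_percentages(logit):
    # Hypothetical helper. Assumes class 1 is "Dog", so sigmoid(logit)
    # is the dog probability and its complement is the cat probability.
    dog = sigmoid(logit)
    return 100 * (1 - dog), 100 * dog


# A logit of 0.0 sits exactly on the decision boundary.
cat_pct, dog_pct = cat_dog_percentages(0.0)
print(f"This image is {cat_pct:.2f}% cat and {dog_pct:.2f}% dog.")
```

A negative logit leans towards `Cat` and a positive one towards `Dog`, which is why the
example reports `1 - score` as the cat percentage.
"""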