Path: examples/vision/image_classification_efficientnet_fine_tuning.py
"""1Title: Image classification via fine-tuning with EfficientNet2Author: [Yixing Fu](https://github.com/yixingfu)3Date created: 2020/06/304Last modified: 2023/07/105Description: Use EfficientNet with weights pre-trained on imagenet for Stanford Dogs classification.6Accelerator: GPU7"""89"""1011## Introduction: what is EfficientNet1213EfficientNet, first introduced in [Tan and Le, 2019](https://arxiv.org/abs/1905.11946)14is among the most efficient models (i.e. requiring least FLOPS for inference)15that reaches State-of-the-Art accuracy on both16imagenet and common image classification transfer learning tasks.1718The smallest base model is similar to [MnasNet](https://arxiv.org/abs/1807.11626), which19reached near-SOTA with a significantly smaller model. By introducing a heuristic way to20scale the model, EfficientNet provides a family of models (B0 to B7) that represents a21good combination of efficiency and accuracy on a variety of scales. Such a scaling22heuristics (compound-scaling, details see23[Tan and Le, 2019](https://arxiv.org/abs/1905.11946)) allows the24efficiency-oriented base model (B0) to surpass models at every scale, while avoiding25extensive grid-search of hyperparameters.2627A summary of the latest updates on the model is available at28[here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet), where various29augmentation schemes and semi-supervised learning approaches are applied to further30improve the imagenet performance of the models. These extensions of the model can be used31by updating weights without changing model architecture.3233## B0 to B7 variants of EfficientNet3435*(This section provides some details on "compound scaling", and can be skipped36if you're only interested in using the models)*3738Based on the [original paper](https://arxiv.org/abs/1905.11946) people may have the39impression that EfficientNet is a continuous family of models created by arbitrarily40choosing scaling factor in as Eq.(3) of the paper. However, choice of resolution,41depth and width are also restricted by many factors:4243- Resolution: Resolutions not divisible by 8, 16, etc. cause zero-padding near boundaries44of some layers which wastes computational resources. This especially applies to smaller45variants of the model, hence the input resolution for B0 and B1 are chosen as 224 and46240.4748- Depth and width: The building blocks of EfficientNet demands channel size to be49multiples of 8.5051- Resource limit: Memory limitation may bottleneck resolution when depth52and width can still increase. In such a situation, increasing depth and/or53width but keep resolution can still improve performance.5455As a result, the depth, width and resolution of each variant of the EfficientNet models56are hand-picked and proven to produce good results, though they may be significantly57off from the compound scaling formula.58Therefore, the keras implementation (detailed below) only provide these 8 models, B0 to B7,59instead of allowing arbitray choice of width / depth / resolution parameters.6061## Keras implementation of EfficientNet6263An implementation of EfficientNet B0 to B7 has been shipped with Keras since v2.3. To64use EfficientNetB0 for classifying 1000 classes of images from ImageNet, run:6566```python67from tensorflow.keras.applications import EfficientNetB068model = EfficientNetB0(weights='imagenet')69```7071This model takes input images of shape `(224, 224, 3)`, and the input data should be in the72range `[0, 255]`. 

Training EfficientNet on ImageNet takes a tremendous amount of resources and
relies on several techniques that are not part of the model architecture itself. Hence the Keras
implementation by default loads pre-trained weights obtained via training with
[AutoAugment](https://arxiv.org/abs/1805.09501).

For the B0 to B7 base models, the input shapes are different. Here is a list of the input shape
expected by each model:

| Base model | Resolution |
|----------------|------------|
| EfficientNetB0 | 224 |
| EfficientNetB1 | 240 |
| EfficientNetB2 | 260 |
| EfficientNetB3 | 300 |
| EfficientNetB4 | 380 |
| EfficientNetB5 | 456 |
| EfficientNetB6 | 528 |
| EfficientNetB7 | 600 |

When the model is intended for transfer learning, the Keras implementation
provides an option to remove the top layers:

```python
model = EfficientNetB0(include_top=False, weights='imagenet')
```

This option excludes the final `Dense` layer that turns the 1280 features of the penultimate
layer into predictions for the 1000 ImageNet classes. Replacing the top layers with custom
layers allows using EfficientNet as a feature extractor in a transfer learning workflow.

Another argument in the model constructor worth noticing is `drop_connect_rate`, which controls
the dropout rate responsible for [stochastic depth](https://arxiv.org/abs/1603.09382).
This parameter serves as a toggle for extra regularization in fine-tuning, but it does not
affect the loaded weights. For example, when stronger regularization is desired, try:

```python
model = EfficientNetB0(weights='imagenet', drop_connect_rate=0.4)
```

The default value is 0.2.

## Example: EfficientNetB0 for Stanford Dogs

EfficientNet is capable of a wide range of image classification tasks.
This makes it a good model for transfer learning.
As an end-to-end example, we will show how to use pre-trained EfficientNetB0 on the
[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/main.html) dataset.
"""

"""
## Setup and data loading
"""

import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf  # For tf.data
import matplotlib.pyplot as plt
import keras
from keras import layers
from keras.applications import EfficientNetB0

# IMG_SIZE is determined by the EfficientNet model choice
IMG_SIZE = 224
BATCH_SIZE = 64


"""
### Loading data

Here we load data from [tensorflow_datasets](https://www.tensorflow.org/datasets)
(hereafter TFDS).
The Stanford Dogs dataset is provided in
TFDS as [stanford_dogs](https://www.tensorflow.org/datasets/catalog/stanford_dogs).
It features 20,580 images that belong to 120 classes of dog breeds
(12,000 for training and 8,580 for testing).

By simply changing `dataset_name` below, you may also try this notebook for
other datasets in TFDS such as
[cifar10](https://www.tensorflow.org/datasets/catalog/cifar10),
[cifar100](https://www.tensorflow.org/datasets/catalog/cifar100),
[food101](https://www.tensorflow.org/datasets/catalog/food101),
etc. When the images are much smaller than the size of the EfficientNet input,
we can simply upsample the input images. It has been shown in
[Tan and Le, 2019](https://arxiv.org/abs/1905.11946) that transfer learning
results are better with increased resolution even if the input images remain small.
"""
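
"""
As a side note (an illustrative aside, not part of the original tutorial), a
registered TFDS dataset's metadata can be inspected through the builder API
before loading it, which is handy when trying other datasets; fetching the
metadata may require network access.
"""

# Inspect split sizes and the label set of a TFDS dataset without loading data.
builder = tfds.builder("stanford_dogs")
print(builder.info.splits["train"].num_examples)  # 12000
print(builder.info.features["label"].num_classes)  # 120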

dataset_name = "stanford_dogs"
(ds_train, ds_test), ds_info = tfds.load(
    dataset_name, split=["train", "test"], with_info=True, as_supervised=True
)
NUM_CLASSES = ds_info.features["label"].num_classes


"""
When the dataset includes images of various sizes, we need to resize them to a
shared size. The Stanford Dogs dataset includes only images that are at least
200x200 pixels in size. Here we resize the images to the input size needed for
EfficientNet.
"""

size = (IMG_SIZE, IMG_SIZE)
ds_train = ds_train.map(lambda image, label: (tf.image.resize(image, size), label))
ds_test = ds_test.map(lambda image, label: (tf.image.resize(image, size), label))

"""
### Visualizing the data

The following code shows the first 9 images with their labels.
"""


def format_label(label):
    string_label = label_info.int2str(label)
    return string_label.split("-")[1]


label_info = ds_info.features["label"]
for i, (image, label) in enumerate(ds_train.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("uint8"))
    plt.title("{}".format(format_label(label)))
    plt.axis("off")


"""
### Data augmentation

We can use the preprocessing layers APIs for image augmentation.
"""

img_augmentation_layers = [
    layers.RandomRotation(factor=0.15),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomFlip(),
    layers.RandomContrast(factor=0.1),
]


def img_augmentation(images):
    for layer in img_augmentation_layers:
        images = layer(images)
    return images


"""
These augmentation layers can be used both as part of
the model we later build, and as a function to preprocess
data before feeding it into the model. Using them as a function makes
it easy to visualize the augmented images. Here we plot 9 examples
of the augmentation result for a given figure.
"""

for image, label in ds_train.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        aug_img = img_augmentation(np.expand_dims(image.numpy(), axis=0))
        aug_img = np.array(aug_img)
        plt.imshow(aug_img[0].astype("uint8"))
        plt.title("{}".format(format_label(label)))
        plt.axis("off")


"""
### Prepare inputs

Once we verify the input data and augmentation are working correctly,
we prepare the dataset for training. The input data are resized to a uniform
`IMG_SIZE`. The labels are put into one-hot
(a.k.a. categorical) encoding. The dataset is batched.
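
As a minimal illustration of the encoding (an aside, not part of the original
tutorial; `depth=5` is a toy value chosen for display), `tf.one_hot` turns
integer class indices into categorical vectors:

```python
labels = tf.constant([0, 3])
print(tf.one_hot(labels, depth=5))
# tf.Tensor(
# [[1. 0. 0. 0. 0.]
#  [0. 0. 0. 1. 0.]], shape=(2, 5), dtype=float32)
```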

Note: `prefetch` and `AUTOTUNE` may in some situations improve
performance, but this depends on the environment and the specific dataset used.
See this [guide](https://www.tensorflow.org/guide/data_performance)
for more information on data pipeline performance.
"""


# One-hot / categorical encoding
def input_preprocess_train(image, label):
    image = img_augmentation(image)
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


def input_preprocess_test(image, label):
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


ds_train = ds_train.map(input_preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(input_preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True)


"""
## Training a model from scratch

We build an EfficientNetB0 with 120 output classes, initialized from scratch:

Note: the accuracy will increase very slowly and the model may overfit.
"""

model = EfficientNetB0(
    include_top=True,
    weights=None,
    classes=NUM_CLASSES,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.summary()

epochs = 40  # @param {type: "slider", min:10, max:100}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)


"""
Training the model is relatively fast. This might make it sound easy to simply train
EfficientNet from scratch on any dataset we want. However, training EfficientNet on smaller datasets,
especially those with lower resolution like CIFAR-100, faces the significant challenge of
overfitting.

Hence training from scratch requires a very careful choice of hyperparameters, and
suitable regularization is difficult to find. It would also be much more demanding in resources.
Plotting the training and validation accuracy
makes it clear that validation accuracy stagnates at a low value.
"""


def plot_hist(hist):
    plt.plot(hist.history["accuracy"])
    plt.plot(hist.history["val_accuracy"])
    plt.title("model accuracy")
    plt.ylabel("accuracy")
    plt.xlabel("epoch")
    plt.legend(["train", "validation"], loc="upper left")
    plt.show()


plot_hist(hist)

"""
## Transfer learning from pre-trained weights

Here we initialize the model with pre-trained ImageNet weights,
and we fine-tune it on our own dataset.
"""


def build_model(num_classes):
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    model = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")

    # Freeze the pretrained weights
    model.trainable = False

    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.BatchNormalization()(x)

    top_dropout_rate = 0.2
    x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)

    # Compile
    model = keras.Model(inputs, outputs, name="EfficientNet")
    optimizer = keras.optimizers.Adam(learning_rate=1e-2)
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )
    return model


"""
The first step to transfer learning is to freeze all layers and train only the top
layers. For this step, a relatively large learning rate (1e-2) can be used.
Note that validation accuracy and loss will usually be better than training
accuracy and loss. This is because the regularization is strong, which only
suppresses training-time metrics.
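
As a small illustration of this effect (an aside, not part of the original
tutorial), dropout perturbs activations only in training mode and acts as the
identity at inference time:

```python
demo = layers.Dropout(0.5)
x = tf.ones((1, 4))
print(demo(x, training=True))   # random units zeroed, the rest scaled by 1 / (1 - 0.5)
print(demo(x, training=False))  # identity: [[1. 1. 1. 1.]]
```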

Note that the convergence may take up to 50 epochs depending on the choice of learning rate.
If the image augmentation layers were not
applied, the validation accuracy may only reach ~60%.
"""

model = build_model(num_classes=NUM_CLASSES)

epochs = 25  # @param {type: "slider", min:8, max:80}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
plot_hist(hist)

"""
The second step is to unfreeze a number of layers and fit the model using a smaller
learning rate. In this example we show unfreezing the top 20 layers, but depending on the
specific dataset it may be desirable to unfreeze more (or even all) layers.

When feature extraction with the
pretrained model works well enough, this step will give a very limited gain in
validation accuracy. In our case we only see a small improvement,
as ImageNet pretraining already exposed the model to a good amount of dog images.

On the other hand, when we use pretrained weights on a dataset that is more different
from ImageNet, this fine-tuning step can be crucial, as the feature extractor also
needs to be adjusted by a considerable amount. Such a situation can be demonstrated
by choosing the CIFAR-100 dataset instead, where fine-tuning boosts validation accuracy
by about 10% to pass 80% on `EfficientNetB0`.

A side note on freezing/unfreezing models: setting `trainable` of a `Model` will
simultaneously set all layers belonging to the `Model` to the same `trainable`
attribute. Each layer is trainable only if both the layer itself and the model
containing it are trainable. Hence when we need to partially freeze/unfreeze
a model, we need to make sure the `trainable` attribute of the model is set
to `True`.
"""


def unfreeze_model(model):
    # We unfreeze the top 20 layers while leaving BatchNorm layers frozen
    for layer in model.layers[-20:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

    optimizer = keras.optimizers.Adam(learning_rate=1e-5)
    model.compile(
        optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
    )


unfreeze_model(model)

epochs = 4  # @param {type: "slider", min:4, max:10}
hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
plot_hist(hist)

"""
### Tips for fine-tuning EfficientNet

On unfreezing layers:

- The `BatchNormalization` layers need to be kept frozen
([more details](https://keras.io/guides/transfer_learning/)).
If they are also made trainable, the
first epoch after unfreezing will significantly reduce accuracy.
- In some cases it may be beneficial to open up only a portion of layers instead of
unfreezing all. This will make fine-tuning much faster when going to larger models like
B7.
- Each block needs to be all turned on or off. This is because the architecture includes
a shortcut from the first layer to the last layer for each block. Not respecting blocks
also significantly harms the final performance. A sketch of block-wise unfreezing is
shown after this list.
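
The following is an illustrative sketch (not part of the original tutorial); it
relies on the layer-name prefixes (`block1a_`, `block2a_`, ..., `top_`) used by
the Keras EfficientNet implementation to select complete blocks:

```python
def unfreeze_blocks(model, prefixes=("block6", "block7", "top")):
    # Unfreeze whole blocks by layer-name prefix, keeping BatchNorm frozen.
    for layer in model.layers:
        if layer.name.startswith(prefixes):
            if not isinstance(layer, layers.BatchNormalization):
                layer.trainable = True
```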

Some other tips for utilizing EfficientNet:

- Larger variants of EfficientNet do not guarantee improved performance, especially for
tasks with less data or fewer classes. In such cases, the larger the variant of
EfficientNet chosen, the harder it is to tune hyperparameters.
- EMA (Exponential Moving Average) is very helpful in training EfficientNet from scratch,
but not so much for transfer learning.
- Do not use the RMSprop setup as in the original paper for transfer learning. The
momentum and learning rate are too high for transfer learning. They will easily corrupt the
pretrained weights and blow up the loss. A quick check is to see whether the loss (as categorical
cross entropy) is getting significantly larger than log(NUM_CLASSES) after the same
epoch. If so, the initial learning rate/momentum is too high (see the snippet after
this list).
- Smaller batch sizes benefit validation accuracy, possibly because they effectively provide
regularization.
"""
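
"""
As an illustration of the loss sanity check above (an aside, not part of the
original tutorial): random guessing over `NUM_CLASSES` classes corresponds to a
categorical cross entropy of about log(NUM_CLASSES).
"""

# For Stanford Dogs, log(120) ≈ 4.79. Training losses far above this value
# after the first few epochs suggest the learning rate / momentum is too high.
print(np.log(NUM_CLASSES))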