Path: blob/master/examples/vision/fully_convolutional_network.py
"""1Title: Image Segmentation using Composable Fully-Convolutional Networks2Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)3Date created: 2023/06/164Last modified: 2023/12/255Description: Using the Fully-Convolutional Network for Image Segmentation.6Accelerator: GPU7"""89"""10## Introduction1112The following example walks through the steps to implement Fully-Convolutional Networks13for Image Segmentation on the Oxford-IIIT Pets dataset.14The model was proposed in the paper,15[Fully Convolutional Networks for Semantic Segmentation by Long et. al.(2014)](https://arxiv.org/abs/1411.4038).16Image segmentation is one of the most common and introductory tasks when it comes to17Computer Vision, where we extend the problem of Image Classification from18one-label-per-image to a pixel-wise classification problem.19In this example, we will assemble the aforementioned Fully-Convolutional Segmentation architecture,20capable of performing Image Segmentation.21The network extends the pooling layer outputs from the VGG in order to perform upsampling22and get a final result. The intermediate outputs coming from the 3rd, 4th and 5th Max-Pooling layers from VGG19 are23extracted out and upsampled at different levels and factors to get a final output with the same shape as that24of the output, but with the class of each pixel present at each location, instead of pixel intensity values.25Different intermediate pool layers are extracted and processed upon for different versions of the network.26The FCN architecture has 3 versions of differing quality.2728- FCN-32S29- FCN-16S30- FCN-8S3132All versions of the model derive their outputs through an iterative processing of33successive intermediate pool layers of the main backbone used.34A better idea can be gained from the figure below.3536|  |37| :--: |38| **Diagram 1**: Combined Architecture Versions (Source: Paper) |3940To get a better idea on Image Segmentation or find more pre-trained models, feel free to41navigate to the [Hugging Face Image Segmentation Models](https://huggingface.co/models?pipeline_tag=image-segmentation) page,42or a [PyImageSearch Blog on Semantic Segmentation](https://pyimagesearch.com/2018/09/03/semantic-segmentation-with-opencv-and-deep-learning/)4344"""4546"""47## Setup Imports48"""4950import os5152os.environ["KERAS_BACKEND"] = "tensorflow"53import keras54from keras import ops55import tensorflow as tf56import matplotlib.pyplot as plt57import tensorflow_datasets as tfds58import numpy as np5960AUTOTUNE = tf.data.AUTOTUNE6162"""63## Set configurations for notebook variables6465We set the required parameters for the experiment.66The chosen dataset has a total of 4 classes per image, with regards to the segmentation mask.67We also set our hyperparameters in this cell.6869Mixed Precision as an option is also available in systems which support it, to reduce70load.71This would make most tensors use `16-bit float` values instead of `32-bit float`72values, in places where it will not adversely affect computation.73This means, during computation, TensorFlow will use `16-bit float` Tensors to increase speed at the cost of precision,74while storing the values in their original default `32-bit float` form.75"""7677NUM_CLASSES = 478INPUT_HEIGHT = 22479INPUT_WIDTH = 22480LEARNING_RATE = 1e-381WEIGHT_DECAY = 1e-482EPOCHS = 2083BATCH_SIZE = 3284MIXED_PRECISION = True85SHUFFLE = True8687# Mixed-precision setting88if MIXED_PRECISION:89policy = 

"""
## Load dataset

We make use of the [Oxford-IIIT Pets dataset](http://www.robots.ox.ac.uk/~vgg/data/pets/)
which contains a total of 7,349 samples and their segmentation masks.
We have 37 classes, with roughly 200 samples per class.
Our training and validation datasets have 3,128 and 552 samples respectively.
Aside from this, our test split has a total of 3,669 samples.

We set a `batch_size` parameter that will batch our samples together, and use a `shuffle`
parameter to mix our samples together.
"""

(train_ds, valid_ds, test_ds) = tfds.load(
    "oxford_iiit_pet",
    split=["train[:85%]", "train[85%:]", "test"],
    batch_size=BATCH_SIZE,
    shuffle_files=SHUFFLE,
)

"""
## Unpack and preprocess dataset

We define a simple function that performs resizing over our
training, validation and test datasets.
We do the same process on the masks as well, to make sure both are aligned in terms of shape and size.
"""


# Image and Mask Pre-processing
def unpack_resize_data(section):
    image = section["image"]
    segmentation_mask = section["segmentation_mask"]

    resize_layer = keras.layers.Resizing(INPUT_HEIGHT, INPUT_WIDTH)

    image = resize_layer(image)
    segmentation_mask = resize_layer(segmentation_mask)

    return image, segmentation_mask


train_ds = train_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
valid_ds = valid_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
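
"""
As an optional, illustrative check, the element spec of the mapped dataset confirms that
images and segmentation masks now share the same spatial dimensions after resizing.
"""

# Both entries should report a (None, 224, 224, C) shape after `unpack_resize_data`.
print(train_ds.element_spec)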

"""
## Visualize one random sample from the pre-processed dataset

We visualize what a random sample in our test split of the dataset looks like, and plot
the segmentation mask on top to see the effective mask areas.
Note that we have performed pre-processing on this dataset too,
so that the image and mask are of the same size.
"""

# Select random image and mask. Cast to NumPy array
# for Matplotlib visualization.

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

# Overlay segmentation mask on top of image.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].set_title("Image")
ax[0].imshow(test_image / 255.0)

ax[1].set_title("Image with segmentation mask overlay")
ax[1].imshow(test_image / 255.0)
ax[1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)
plt.show()

"""
## Perform VGG-specific pre-processing

`keras.applications.VGG19` requires the use of a `preprocess_input` function that
performs ImageNet-style, channel-wise mean subtraction on the input images.
"""


def preprocess_data(image, segmentation_mask):
    image = keras.applications.vgg19.preprocess_input(image)

    return image, segmentation_mask


train_ds = (
    train_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
valid_ds = (
    valid_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)
test_ds = (
    test_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
    .shuffle(buffer_size=1024)
    .prefetch(buffer_size=1024)
)

"""
## Model Definition

The Fully-Convolutional Network boasts a simple architecture composed of only
`keras.layers.Conv2D` layers, `keras.layers.Dense` layers and `keras.layers.Dropout`
layers.

**Diagram 2**: Generic FCN Forward Pass (Source: Paper)

Pixel-wise prediction is performed by having a Softmax Convolutional layer with the same
spatial size as the image, such that we can perform a direct, per-pixel comparison.
This also lets us track several important metrics on the network, such as Accuracy and Mean-Intersection-over-Union.
"""

"""
### Backbone (VGG-19)

We use the [VGG-19 network](https://keras.io/api/applications/vgg/) as the backbone, as
the paper suggests it to be one of the most effective backbones for this network.
We extract different outputs from the network by making use of `keras.models.Model`.
Following this, we add layers on top to make a network perfectly simulating that of
Diagram 1.
The backbone's `keras.layers.Dense` layers will be converted to `keras.layers.Conv2D`
layers based on the [original Caffe code present here.](https://github.com/linxi159/FCN-caffe/blob/master/pascalcontext-fcn16s/net.py)
All 3 networks will share the same backbone weights, but will have differing results
based on their extensions.
We make the backbone non-trainable to improve training time requirements.
It is also noted in the paper that making the backbone trainable does not yield major benefits.
"""

input_layer = keras.Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, 3))

# VGG Model backbone with pre-trained ImageNet weights.
vgg_model = keras.applications.vgg19.VGG19(include_top=True, weights="imagenet")

# Extracting different outputs from same model
fcn_backbone = keras.models.Model(
    inputs=vgg_model.layers[1].input,
    outputs=[
        vgg_model.get_layer(block_name).output
        for block_name in ["block3_pool", "block4_pool", "block5_pool"]
    ],
)

# Setting backbone to be non-trainable
fcn_backbone.trainable = False

x = fcn_backbone(input_layer)
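
"""
For reference (an illustrative check, not required for training), the three extracted
pooling outputs have progressively smaller spatial dimensions. For a 224x224 input we
expect roughly 28x28, 14x14 and 7x7 feature maps, which is why the heads below upsample
by factors of 8, 16 and 32 respectively.
"""

for pool_name, pool_tensor in zip(["block3_pool", "block4_pool", "block5_pool"], x):
    print(pool_name, pool_tensor.shape)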

# Converting Dense layers to Conv2D layers
units = [4096, 4096]
dense_convs = []

for filter_idx in range(len(units)):
    dense_conv = keras.layers.Conv2D(
        filters=units[filter_idx],
        kernel_size=(7, 7) if filter_idx == 0 else (1, 1),
        strides=(1, 1),
        activation="relu",
        padding="same",
        use_bias=False,
        kernel_initializer=keras.initializers.Constant(1.0),
    )
    dense_convs.append(dense_conv)
    dropout_layer = keras.layers.Dropout(0.5)
    dense_convs.append(dropout_layer)

dense_convs = keras.Sequential(dense_convs)
dense_convs.trainable = False

x[-1] = dense_convs(x[-1])

pool3_output, pool4_output, pool5_output = x

"""
### FCN-32S

We extend the last output, perform a `1x1 Convolution` and perform 2D Bilinear Upsampling
by a factor of 32 to get an image of the same size as that of our input.
We use a simple `keras.layers.UpSampling2D` layer over a `keras.layers.Conv2DTranspose`
since it yields performance benefits from being a deterministic mathematical operation
over a Convolutional operation.
It is also noted in the paper that making the Up-sampling parameters trainable does not yield benefits.
Original experiments of the paper used Upsampling as well.
"""

# 1x1 convolution to set channels = number of classes
pool5 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="relu",
)

# Get Softmax outputs for all classes
fcn32s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn32s_upsampling = keras.layers.UpSampling2D(
    size=(32, 32),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

final_fcn32s_pool = pool5(pool5_output)
final_fcn32s_output = fcn32s_conv_layer(final_fcn32s_pool)
final_fcn32s_output = fcn32s_upsampling(final_fcn32s_output)

fcn32s_model = keras.Model(inputs=input_layer, outputs=final_fcn32s_output)
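
"""
A quick, illustrative check of the upsampling arithmetic: the `block5_pool` feature map is
7x7 for a 224x224 input, so a factor-32 bilinear upsample restores the input resolution.
"""

# Expected: (None, 224, 224, NUM_CLASSES), since 7 * 32 = 224.
print(fcn32s_model.outputs[0].shape)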

"""
### FCN-16S

The pooling output from the FCN-32S is extended and added to the 4th-level Pooling output
of our backbone.
Following this, we upsample by a factor of 16 to get an image of the same
size as that of our input.
"""

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool4 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool4_output)

# Intermediate up-sample
pool5 = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn32s_pool)

# Get Softmax outputs for all classes
fcn16s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn16s_upsample_layer = keras.layers.UpSampling2D(
    size=(16, 16),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn16s_pool = keras.layers.Add()([pool4, pool5])
final_fcn16s_output = fcn16s_conv_layer(final_fcn16s_pool)
final_fcn16s_output = fcn16s_upsample_layer(final_fcn16s_output)

fcn16s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn16s_output)

"""
### FCN-8S

The pooling output from the FCN-16S is extended once more, and added to the 3rd-level
Pooling output of our backbone.
This result is upsampled by a factor of 8 to get an image of the same size as that of our input.
"""

# 1x1 convolution to set channels = number of classes
# Followed from the original Caffe implementation
pool3 = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    padding="same",
    strides=(1, 1),
    activation="linear",
    kernel_initializer=keras.initializers.Zeros(),
)(pool3_output)

# Intermediate up-sample
intermediate_pool_output = keras.layers.UpSampling2D(
    size=(2, 2),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)(final_fcn16s_pool)

# Get Softmax outputs for all classes
fcn8s_conv_layer = keras.layers.Conv2D(
    filters=NUM_CLASSES,
    kernel_size=(1, 1),
    activation="softmax",
    padding="same",
    strides=(1, 1),
)

# Up-sample to original image size
fcn8s_upsample_layer = keras.layers.UpSampling2D(
    size=(8, 8),
    data_format=keras.backend.image_data_format(),
    interpolation="bilinear",
)

# Add intermediate outputs
final_fcn8s_pool = keras.layers.Add()([pool3, intermediate_pool_output])
final_fcn8s_output = fcn8s_conv_layer(final_fcn8s_pool)
final_fcn8s_output = fcn8s_upsample_layer(final_fcn8s_output)

fcn8s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn8s_output)

"""
### Load weights into backbone

It was noted in the paper, as well as through experimentation, that extracting the weights
of the last 2 Fully-connected Dense layers from the backbone, reshaping them to
fit the `keras.layers.Conv2D` layers we had previously converted from
`keras.layers.Dense`, and setting them on those layers yields far better results and a significant
increase in mIOU performance.
"""

# VGG's last 2 layers
weights1 = vgg_model.get_layer("fc1").get_weights()[0]
weights2 = vgg_model.get_layer("fc2").get_weights()[0]

weights1 = weights1.reshape(7, 7, 512, 4096)
weights2 = weights2.reshape(1, 1, 4096, 4096)

dense_convs.layers[0].set_weights([weights1])
dense_convs.layers[2].set_weights([weights2])
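
"""
The reshape above works because VGG19's `fc1` kernel maps a flattened 7x7x512 feature map
(25,088 values) to 4,096 units, which is exactly the parameter count of a 7x7 convolution
with 512 input channels and 4,096 filters. The check below is purely illustrative.
"""

# fc1 kernel: (25088, 4096), and 7 * 7 * 512 == 25088.
print(vgg_model.get_layer("fc1").get_weights()[0].shape, 7 * 7 * 512)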

"""
## Training

The original paper talks about making use of [SGD with Momentum](https://keras.io/api/optimizers/sgd/) as the optimizer of choice.
But it was noticed during experimentation that
[AdamW](https://keras.io/api/optimizers/adamw/)
yielded better results in terms of mIOU and Pixel-wise Accuracy.
"""

"""
### FCN-32S
"""

fcn32s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn32s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn32s_model.compile(
    optimizer=fcn32s_optimizer,
    loss=fcn32s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn32s_history = fcn32s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)

"""
### FCN-16S
"""

fcn16s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn16s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn16s_model.compile(
    optimizer=fcn16s_optimizer,
    loss=fcn16s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn16s_history = fcn16s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)

"""
### FCN-8S
"""

fcn8s_optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

fcn8s_loss = keras.losses.SparseCategoricalCrossentropy()

# Maintain mIOU and Pixel-wise Accuracy as metrics
fcn8s_model.compile(
    optimizer=fcn8s_optimizer,
    loss=fcn8s_loss,
    metrics=[
        keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

fcn8s_history = fcn8s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)

"""
## Visualizations
"""

"""
### Plotting metrics for training run

We perform a comparative study between all 3 versions of the model by tracking training
and validation metrics of Accuracy, Loss and Mean IoU.
"""

total_plots = len(fcn32s_history.history)
cols = total_plots // 2

rows = total_plots // cols

if total_plots % cols != 0:
    rows += 1

# Set all history dictionary objects
fcn32s_dict = fcn32s_history.history
fcn16s_dict = fcn16s_history.history
fcn8s_dict = fcn8s_history.history

pos = range(1, total_plots + 1)
plt.figure(figsize=(15, 10))

for i, ((key_32s, value_32s), (key_16s, value_16s), (key_8s, value_8s)) in enumerate(
    zip(fcn32s_dict.items(), fcn16s_dict.items(), fcn8s_dict.items())
):
    plt.subplot(rows, cols, pos[i])
    plt.plot(range(len(value_32s)), value_32s)
    plt.plot(range(len(value_16s)), value_16s)
    plt.plot(range(len(value_8s)), value_8s)
    plt.title(str(key_32s) + " (combined)")
    plt.legend(["FCN-32S", "FCN-16S", "FCN-8S"])

plt.show()
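
"""
As a quick numeric complement to the plots above (illustrative only), we can also print
the final-epoch metrics recorded in each history object, without assuming specific metric
key names.
"""

for model_name, model_history in [
    ("FCN-32S", fcn32s_history),
    ("FCN-16S", fcn16s_history),
    ("FCN-8S", fcn8s_history),
]:
    final_metrics = {key: values[-1] for key, values in model_history.history.items()}
    print(model_name, final_metrics)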

"""
### Visualizing predicted segmentation masks

To understand the results and see them better, we pick a random image from the test
dataset and perform inference on it to see the masks generated by each model.
Note: For better results, the model must be trained for a higher number of epochs.
"""

images, masks = next(iter(test_ds))
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

# Get random test image and mask
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

pred_image = ops.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

# Perform inference on FCN-32S
pred_mask_32s = fcn32s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_32s = np.argmax(pred_mask_32s, axis=-1)
pred_mask_32s = pred_mask_32s[0, ...]

# Perform inference on FCN-16S
pred_mask_16s = fcn16s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_16s = np.argmax(pred_mask_16s, axis=-1)
pred_mask_16s = pred_mask_16s[0, ...]

# Perform inference on FCN-8S
pred_mask_8s = fcn8s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_8s = np.argmax(pred_mask_8s, axis=-1)
pred_mask_8s = pred_mask_8s[0, ...]

# Plot all results
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

fig.delaxes(ax[0, 2])

ax[0, 0].set_title("Image")
ax[0, 0].imshow(test_image / 255.0)

ax[0, 1].set_title("Image with ground truth overlay")
ax[0, 1].imshow(test_image / 255.0)
ax[0, 1].imshow(
    test_mask,
    cmap="inferno",
    alpha=0.6,
)

ax[1, 0].set_title("Image with FCN-32S mask overlay")
ax[1, 0].imshow(test_image / 255.0)
ax[1, 0].imshow(pred_mask_32s, cmap="inferno", alpha=0.6)

ax[1, 1].set_title("Image with FCN-16S mask overlay")
ax[1, 1].imshow(test_image / 255.0)
ax[1, 1].imshow(pred_mask_16s, cmap="inferno", alpha=0.6)

ax[1, 2].set_title("Image with FCN-8S mask overlay")
ax[1, 2].imshow(test_image / 255.0)
ax[1, 2].imshow(pred_mask_8s, cmap="inferno", alpha=0.6)

plt.show()

"""
## Conclusion

The Fully-Convolutional Network is an exceptionally simple network that has yielded
strong results in Image Segmentation tasks across different benchmarks.
With the advent of better mechanisms like [Attention](https://arxiv.org/abs/1706.03762) as used in
[SegFormer](https://arxiv.org/abs/2105.15203) and
[DeTR](https://arxiv.org/abs/2005.12872), this model serves as a quick way to iterate and
find baselines for this task on unknown data.
"""

"""
## Acknowledgements

I thank [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush
Thakur](https://twitter.com/ayushthakur0) and [Ritwik
Raha](https://twitter.com/ritwik_raha) for giving a preliminary review of the example.
I also thank the [Google Developer
Experts](https://developers.google.com/community/experts) program.

"""