Path: blob/master/examples/vision/deeplabv3_plus.py
3507 views
"""1Title: Multiclass semantic segmentation using DeepLabV3+2Author: [Soumik Rakshit](http://github.com/soumik12345)3Date created: 2021/08/314Last modified: 2024/01/055Description: Implement DeepLabV3+ architecture for Multi-class Semantic Segmentation.6Accelerator: GPU7Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)8"""910"""11## Introduction1213Semantic segmentation, with the goal to assign semantic labels to every pixel in an image,14is an essential computer vision task. In this example, we implement15the **DeepLabV3+** model for multi-class semantic segmentation, a fully-convolutional16architecture that performs well on semantic segmentation benchmarks.1718### References:1920- [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)21- [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)22- [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](https://arxiv.org/abs/1606.00915)23"""2425"""26## Downloading the data2728We will use the [Crowd Instance-level Human Parsing Dataset](https://arxiv.org/abs/1811.12596)29for training our model. The Crowd Instance-level Human Parsing (CIHP) dataset has 38,280 diverse human images.30Each image in CIHP is labeled with pixel-wise annotations for 20 categories, as well as instance-level identification.31This dataset can be used for the "human part segmentation" task.32"""333435import keras36from keras import layers37from keras import ops3839import os40import numpy as np41from glob import glob42import cv243from scipy.io import loadmat44import matplotlib.pyplot as plt4546# For data preprocessing47from tensorflow import image as tf_image48from tensorflow import data as tf_data49from tensorflow import io as tf_io5051"""shell52gdown "1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz&confirm=t"53unzip -q instance-level-human-parsing.zip54"""5556"""57## Creating a TensorFlow Dataset5859Training on the entire CIHP dataset with 38,280 images takes a lot of time, hence we will be using60a smaller subset of 200 images for training our model in this example.61"""6263IMAGE_SIZE = 51264BATCH_SIZE = 465NUM_CLASSES = 2066DATA_DIR = "./instance-level_human_parsing/instance-level_human_parsing/Training"67NUM_TRAIN_IMAGES = 100068NUM_VAL_IMAGES = 506970train_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[:NUM_TRAIN_IMAGES]71train_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[:NUM_TRAIN_IMAGES]72val_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[73NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES74]75val_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[76NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES77]787980def read_image(image_path, mask=False):81image = tf_io.read_file(image_path)82if mask:83image = tf_image.decode_png(image, channels=1)84image.set_shape([None, None, 1])85image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])86else:87image = tf_image.decode_png(image, channels=3)88image.set_shape([None, None, 3])89image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])90return image919293def load_data(image_list, mask_list):94image = read_image(image_list)95mask = read_image(mask_list, mask=True)96return image, mask979899def data_generator(image_list, mask_list):100dataset = tf_data.Dataset.from_tensor_slices((image_list, mask_list))101dataset = dataset.map(load_data, num_parallel_calls=tf_data.AUTOTUNE)102dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)103return dataset104105106train_dataset = data_generator(train_images, train_masks)107val_dataset = data_generator(val_images, val_masks)108109print("Train Dataset:", train_dataset)110print("Val Dataset:", val_dataset)111112"""113## Building the DeepLabV3+ model114115DeepLabv3+ extends DeepLabv3 by adding an encoder-decoder structure. The encoder module116processes multiscale contextual information by applying dilated convolution at multiple117scales, while the decoder module refines the segmentation results along object boundaries.118119120121**Dilated convolution:** With dilated convolution, as we go deeper in the network, we can keep the122stride constant but with larger field-of-view without increasing the number of parameters123or the amount of computation. Besides, it enables larger output feature maps, which is124useful for semantic segmentation.125126The reason for using **Dilated Spatial Pyramid Pooling** is that it was shown that as the127sampling rate becomes larger, the number of valid filter weights (i.e., weights that128are applied to the valid feature region, instead of padded zeros) becomes smaller.129"""130131132def convolution_block(133block_input,134num_filters=256,135kernel_size=3,136dilation_rate=1,137use_bias=False,138):139x = layers.Conv2D(140num_filters,141kernel_size=kernel_size,142dilation_rate=dilation_rate,143padding="same",144use_bias=use_bias,145kernel_initializer=keras.initializers.HeNormal(),146)(block_input)147x = layers.BatchNormalization()(x)148return ops.nn.relu(x)149150151def DilatedSpatialPyramidPooling(dspp_input):152dims = dspp_input.shape153x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)154x = convolution_block(x, kernel_size=1, use_bias=True)155out_pool = layers.UpSampling2D(156size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),157interpolation="bilinear",158)(x)159160out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)161out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)162out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)163out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)164165x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])166output = convolution_block(x, kernel_size=1)167return output168169170"""171The encoder features are first bilinearly upsampled by a factor 4, and then172concatenated with the corresponding low-level features from the network backbone that173have the same spatial resolution. For this example, we174use a ResNet50 pretrained on ImageNet as the backbone model, and we use175the low-level features from the `conv4_block6_2_relu` block of the backbone.176"""177178179def DeeplabV3Plus(image_size, num_classes):180model_input = keras.Input(shape=(image_size, image_size, 3))181preprocessed = keras.applications.resnet50.preprocess_input(model_input)182resnet50 = keras.applications.ResNet50(183weights="imagenet", include_top=False, input_tensor=preprocessed184)185x = resnet50.get_layer("conv4_block6_2_relu").output186x = DilatedSpatialPyramidPooling(x)187188input_a = layers.UpSampling2D(189size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),190interpolation="bilinear",191)(x)192input_b = resnet50.get_layer("conv2_block3_2_relu").output193input_b = convolution_block(input_b, num_filters=48, kernel_size=1)194195x = layers.Concatenate(axis=-1)([input_a, input_b])196x = convolution_block(x)197x = convolution_block(x)198x = layers.UpSampling2D(199size=(image_size // x.shape[1], image_size // x.shape[2]),200interpolation="bilinear",201)(x)202model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)203return keras.Model(inputs=model_input, outputs=model_output)204205206model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)207model.summary()208209"""210## Training211212We train the model using sparse categorical crossentropy as the loss function, and213Adam as the optimizer.214"""215216loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)217model.compile(218optimizer=keras.optimizers.Adam(learning_rate=0.001),219loss=loss,220metrics=["accuracy"],221)222223history = model.fit(train_dataset, validation_data=val_dataset, epochs=25)224225plt.plot(history.history["loss"])226plt.title("Training Loss")227plt.ylabel("loss")228plt.xlabel("epoch")229plt.show()230231plt.plot(history.history["accuracy"])232plt.title("Training Accuracy")233plt.ylabel("accuracy")234plt.xlabel("epoch")235plt.show()236237plt.plot(history.history["val_loss"])238plt.title("Validation Loss")239plt.ylabel("val_loss")240plt.xlabel("epoch")241plt.show()242243plt.plot(history.history["val_accuracy"])244plt.title("Validation Accuracy")245plt.ylabel("val_accuracy")246plt.xlabel("epoch")247plt.show()248249"""250## Inference using Colormap Overlay251252The raw predictions from the model represent a one-hot encoded tensor of shape `(N, 512, 512, 20)`253where each one of the 20 channels is a binary mask corresponding to a predicted label.254In order to visualize the results, we plot them as RGB segmentation masks where each pixel255is represented by a unique color corresponding to the particular label predicted. We can easily256find the color corresponding to each label from the `human_colormap.mat` file provided as part257of the dataset. We would also plot an overlay of the RGB segmentation mask on the input image as258this further helps us to identify the different categories present in the image more intuitively.259"""260261# Loading the Colormap262colormap = loadmat(263"./instance-level_human_parsing/instance-level_human_parsing/human_colormap.mat"264)["colormap"]265colormap = colormap * 100266colormap = colormap.astype(np.uint8)267268269def infer(model, image_tensor):270predictions = model.predict(np.expand_dims((image_tensor), axis=0))271predictions = np.squeeze(predictions)272predictions = np.argmax(predictions, axis=2)273return predictions274275276def decode_segmentation_masks(mask, colormap, n_classes):277r = np.zeros_like(mask).astype(np.uint8)278g = np.zeros_like(mask).astype(np.uint8)279b = np.zeros_like(mask).astype(np.uint8)280for l in range(0, n_classes):281idx = mask == l282r[idx] = colormap[l, 0]283g[idx] = colormap[l, 1]284b[idx] = colormap[l, 2]285rgb = np.stack([r, g, b], axis=2)286return rgb287288289def get_overlay(image, colored_mask):290image = keras.utils.array_to_img(image)291image = np.array(image).astype(np.uint8)292overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)293return overlay294295296def plot_samples_matplotlib(display_list, figsize=(5, 3)):297_, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)298for i in range(len(display_list)):299if display_list[i].shape[-1] == 3:300axes[i].imshow(keras.utils.array_to_img(display_list[i]))301else:302axes[i].imshow(display_list[i])303plt.show()304305306def plot_predictions(images_list, colormap, model):307for image_file in images_list:308image_tensor = read_image(image_file)309prediction_mask = infer(image_tensor=image_tensor, model=model)310prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)311overlay = get_overlay(image_tensor, prediction_colormap)312plot_samples_matplotlib(313[image_tensor, overlay, prediction_colormap], figsize=(18, 14)314)315316317"""318### Inference on Train Images319"""320321plot_predictions(train_images[:4], colormap, model=model)322323"""324### Inference on Validation Images325326You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deeplabv3p-resnet50)327and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Human-Part-Segmentation).328"""329330plot_predictions(val_images[:4], colormap, model=model)331332333