"""
Title: Highly accurate boundaries segmentation using BASNet
Author: [Hamid Ali](https://github.com/hamidriasat)
Date created: 2023/05/30
Last modified: 2024/10/02
Description: Boundaries aware segmentation model trained on the DUTS dataset.
Accelerator: GPU
"""

"""
## Introduction

Deep semantic segmentation algorithms have improved a lot recently, but they still fail to
correctly predict pixels around object boundaries. In this example we implement the
**Boundary-Aware Segmentation Network (BASNet)**. Using a two-stage predict-and-refine
architecture and a hybrid loss, it can predict highly accurate boundaries and fine structures
for image segmentation.

### References:

- [Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704)
- [BASNet Keras Implementation](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
- [Learning to Detect Salient Objects with Image-level Supervision](https://openaccess.thecvf.com/content_cvpr_2017/html/Wang_Learning_to_Detect_CVPR_2017_paper.html)
"""

"""
## Download the Data

We will use the [DUTS-TE](http://saliencydetection.net/duts/) dataset for training. It has 5,019
images, but we will use only 140 of them for training and validation to save notebook running
time. DUTS is a relatively large salient object segmentation dataset, containing diverse
textures and structures common to real-world images in both foreground and background.
"""

import os

# Because of the use of tf.image.ssim in the loss,
# this example requires TensorFlow.
# The rest of the code is backend-agnostic.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
from glob import glob
import matplotlib.pyplot as plt

import keras_hub
import tensorflow as tf
import keras
from keras import layers, ops

keras.config.disable_traceback_filtering()

"""
## Define Hyperparameters
"""

IMAGE_SIZE = 288
BATCH_SIZE = 4
OUT_CLASSES = 1
TRAIN_SPLIT_RATIO = 0.90

"""
## Create `PyDataset`s

We will use `load_paths()` to load 140 image/mask path pairs and split them into training and
validation sets, then wrap each set in a `PyDataset` object. With `TRAIN_SPLIT_RATIO = 0.90`,
that gives 126 training and 14 validation pairs.
"""

data_dir = keras.utils.get_file(
    origin="http://saliencydetection.net/duts/download/DUTS-TE.zip",
    extract=True,
)
data_dir = os.path.join(data_dir, "DUTS-TE")


def load_paths(path, split_ratio):
    images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
    masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:140]
    len_ = int(len(images) * split_ratio)
    return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])


class Dataset(keras.utils.PyDataset):
    def __init__(
        self,
        image_paths,
        mask_paths,
        img_size,
        out_classes,
        batch,
        shuffle=True,
        **kwargs,
    ):
        if shuffle:
            # Shuffle images and masks with the same permutation so pairs stay aligned.
            perm = np.random.permutation(len(image_paths))
            image_paths = [image_paths[i] for i in perm]
            mask_paths = [mask_paths[i] for i in perm]
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.img_size = img_size
        self.out_classes = out_classes
        self.batch_size = batch
        super().__init__(**kwargs)

    def __len__(self):
        # Incomplete final batches are dropped.
        return len(self.image_paths) // self.batch_size

    def __getitem__(self, idx):
        batch_x, batch_y = [], []
        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
            x, y = self.preprocess(
                self.image_paths[i],
                self.mask_paths[i],
                self.img_size,
            )
            batch_x.append(x)
            batch_y.append(y)
        batch_x = np.stack(batch_x, axis=0)
        batch_y = np.stack(batch_y, axis=0)
        return batch_x, batch_y

    def read_image(self, path, size, mode):
        # Load, resize, and scale pixel values to [0, 1].
        x = keras.utils.load_img(path, target_size=size, color_mode=mode)
        x = keras.utils.img_to_array(x)
        x = (x / 255.0).astype(np.float32)
        return x

    def preprocess(self, image_path, mask_path, img_size):
        image = self.read_image(image_path, (img_size, img_size), mode="rgb")
        mask = self.read_image(mask_path, (img_size, img_size), mode="grayscale")
        return image, mask


train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)

train_dataset = Dataset(
    train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
)
val_dataset = Dataset(
    val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
)

"""
## Visualize Data
"""


def display(display_list):
    title = ["Input Image", "True Mask", "Predicted Mask"]

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(keras.utils.array_to_img(display_list[i]), cmap="gray")
        plt.axis("off")
    plt.show()


for image, mask in val_dataset:
    display([image[0], mask[0]])
    break

"""
## Analyze Mask

Let's print the unique values of the mask displayed above. Although the mask belongs to a single
class, its intensity varies from low (0) to high (255). This variation in intensity makes it
hard for a network to generate a good segmentation map for **salient or camouflaged object
segmentation**. Thanks to its Residual Refinement Module (RRM), BASNet is good at generating
highly accurate boundaries and fine structures.
"""

print(f"Unique values count: {len(np.unique((mask[0] * 255)))}")
print("Unique values:")
print(np.unique((mask[0] * 255)).astype(int))

"""
## Building the BASNet Model

BASNet comprises a predict-refine architecture and a hybrid loss.
The predict-refine architecture consists of a densely supervised encoder-decoder network and a
residual refinement module, which are respectively used to predict and refine a segmentation
probability map.
"""


def basic_block(x_input, filters, stride=1, down_sample=None, activation=None):
    """Creates a residual (identity) block with two 3x3 convolutions."""
    residual = x_input

    x = layers.Conv2D(filters, (3, 3), strides=stride, padding="same", use_bias=False)(
        x_input
    )
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(filters, (3, 3), strides=(1, 1), padding="same", use_bias=False)(
        x
    )
    x = layers.BatchNormalization()(x)

    # `down_sample` is an optional tensor that replaces the identity shortcut.
    if down_sample is not None:
        residual = down_sample

    x = layers.Add()([x, residual])

    if activation is not None:
        x = layers.Activation(activation)(x)

    return x


def convolution_block(x_input, filters, dilation=1):
    """Apply convolution + batch normalization + relu layer."""
    x = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)(x_input)
    x = layers.BatchNormalization()(x)
    return layers.Activation("relu")(x)


def segmentation_head(x_input, out_classes, final_size):
    """Map each decoder stage output to model output classes."""
    x = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_input)

    if final_size is not None:
        x = layers.Resizing(final_size[0], final_size[1])(x)

    return x


def get_resnet_block(resnet, block_num):
    """Extract and return a ResNet-34 block."""
    extractor_levels = ["P2", "P3", "P4", "P5"]
    num_blocks = resnet.stackwise_num_blocks
    if block_num == 0:
        x = resnet.get_layer("pool1_pool").output
    else:
        x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
    y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
    return keras.models.Model(
        inputs=x,
        outputs=y,
        name=f"resnet_block{block_num + 1}",
    )


"""
## Prediction Module

The prediction module is a heavy encoder-decoder structure like U-Net. The encoder includes an
input convolutional layer and six stages. The first four stages are adopted from ResNet-34 and
the remaining two are built from basic residual blocks. Because the first convolution and
pooling layers of ResNet-34 are skipped, we use `get_resnet_block()` to extract the first four
blocks. The bridge and each decoder stage use three convolutional layers and produce a side
output. The module therefore produces seven segmentation probability maps during training (one
per decoder stage plus one from the bridge); the map from the last, full-resolution decoder
stage is treated as the coarse output that is passed to the refinement module.
"""


def basnet_predict(input_shape, out_classes):
    """BASNet Prediction Module; outputs coarse label maps."""
    filters = 64
    num_stages = 6

    x_input = layers.Input(input_shape)

    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

    resnet = keras_hub.models.ResNetBackbone(
        input_conv_filters=[64],
        input_conv_kernel_sizes=[7],
        stackwise_num_filters=[64, 128, 256, 512],
        stackwise_num_blocks=[3, 4, 6, 3],
        stackwise_num_strides=[1, 2, 2, 2],
        block_type="basic_block",
    )

    encoder_blocks = []
    for i in range(num_stages):
        if i < 4:  # First four stages are adopted from ResNet-34 blocks.
            x = get_resnet_block(resnet, i)(x)
            encoder_blocks.append(x)
            x = layers.Activation("relu")(x)
        else:  # Last 2 stages consist of three basic resnet blocks.
            x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            encoder_blocks.append(x)

    # -------------Bridge-------------
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    encoder_blocks.append(x)

    # -------------Decoder-------------
    decoder_blocks = []
    for i in reversed(range(num_stages)):
        if i != (num_stages - 1):  # Except first, scale other decoder stages.
            shape = x.shape
            x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)

        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        decoder_blocks.append(x)

    decoder_blocks.reverse()  # Change order from last to first decoder stage.
    decoder_blocks.append(encoder_blocks[-1])  # Copy bridge to decoder.

    # -------------Side Outputs--------------
    decoder_blocks = [
        segmentation_head(decoder_block, out_classes, input_shape[:2])
        for decoder_block in decoder_blocks
    ]

    return keras.models.Model(inputs=x_input, outputs=decoder_blocks)


"""
## Residual Refinement Module

The Residual Refinement Module (RRM), designed as a residual block, aims to refine the coarse
(blurry, with noisy boundaries) segmentation maps generated by the prediction module. Like the
prediction module, it is an encoder-decoder structure, but a lightweight one with 4 stages, each
containing one `convolution_block()`.
At the end, it adds the coarse and residual outputs to generate the refined output.
"""


def basnet_rrm(base_model, out_classes):
    """BASNet Residual Refinement Module (RRM); outputs the refined label map."""
    num_stages = 4
    filters = 64

    # Coarse (pre-sigmoid) map from the last, full-resolution decoder stage.
    x_input = base_model.output[0]

    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

    encoder_blocks = []
    for _ in range(num_stages):
        x = convolution_block(x, filters=filters)
        encoder_blocks.append(x)
        x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)

    # -------------Bridge--------------
    x = convolution_block(x, filters=filters)

    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        shape = x.shape
        x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters)

    x = segmentation_head(x, out_classes, None)  # Segmentation head.

    # ------------- refined = coarse + residual
    x = layers.Add()([x_input, x])  # Add prediction + refinement output

    return keras.models.Model(inputs=[base_model.input], outputs=[x])


"""
## Combine the Prediction and Refinement Modules
"""


class BASNet(keras.Model):
    def __init__(self, input_shape, out_classes):
        """BASNet, a combination of two modules: the Prediction Module and
        the Residual Refinement Module (RRM)."""

        # Prediction model.
        predict_model = basnet_predict(input_shape, out_classes)
        # Refinement model.
        refine_model = basnet_rrm(predict_model, out_classes)

        # Combine outputs: the refined map first, then the seven side outputs.
        # List concatenation avoids mutating `refine_model.outputs` in place.
        output = refine_model.outputs + predict_model.outputs

        # Activations.
        output = [layers.Activation("sigmoid")(x) for x in output]
        super().__init__(inputs=predict_model.input, outputs=output)

        self.smooth = 1.0e-9
        # Binary Cross Entropy loss.
        self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
        # Structural Similarity Index value.
        self.ssim_value = tf.image.ssim
        # Jaccard / IoU loss.
        self.iou_value = self.calculate_iou

    def calculate_iou(
        self,
        y_true,
        y_pred,
    ):
        """Calculate intersection over union (IoU) between images."""
        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
        union = union - intersection
        return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)

    def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
        total = 0.0
        for y_pred_i in y_pred:  # y_pred = refine_model.outputs + predict_model.outputs
            # Each of the eight outputs is scored against the mask with the full hybrid loss.
            cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)

            ssim_value = self.ssim_value(y_true, y_pred_i, max_val=1)
            ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)

            iou_value = self.iou_value(y_true, y_pred_i)
            iou_loss = 1 - iou_value

            # Add all three losses.
            total += cross_entropy_loss + ssim_loss + iou_loss
        return total


"""
## Hybrid Loss

Another important feature of BASNet is its hybrid loss function, a combination of binary
cross-entropy, structural similarity, and intersection-over-union losses. Together they guide
the network to learn a three-level (pixel-, patch-, and map-level) hierarchy of representations.
"""


basnet_model = BASNet(
    input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
)  # Create model.
basnet_model.summary()  # Show model summary.

optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
# Compile model.
basnet_model.compile(
    optimizer=optimizer,
    metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
)

"""
### Train the Model
"""

basnet_model.fit(train_dataset, validation_data=val_dataset, epochs=1)

"""
### Visualize Predictions

In the paper, BASNet was trained on the DUTS-TR dataset, which has 10,553 images.
The model was trained for 400k iterations with a batch size of eight and without a validation
dataset. After training, the model was evaluated on the DUTS-TE dataset and achieved a mean
absolute error of `0.042`.

Since BASNet is a deep model that cannot be trained in the short time a Keras example notebook
allows, we will load pretrained weights from [here](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
to show the model's predictions. Due to limited computing power, this model was trained for only
120k iterations, but it still demonstrates the model's capabilities. For further details about
the training parameters, please check the linked repository.
"""

import gdown

gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")


def normalize_output(prediction):
    # Min-max scale a prediction to [0, 1] for display.
    max_value = np.max(prediction)
    min_value = np.min(prediction)
    return (prediction - min_value) / (max_value - min_value)


# Load weights.
basnet_model.load_weights("./basnet_weights.h5")

"""
### Make Predictions
"""

for (image, mask), _ in zip(val_dataset, range(1)):
    pred_mask = basnet_model.predict(image)
    # pred_mask[0] is the refined (RRM) output for the batch; the second [0] picks the first image.
    display([image[0], mask[0], normalize_output(pred_mask[0][0])])
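"""
To complement the visual comparison with a number, the mean absolute error (the metric the paper
reports on DUTS-TE) can be computed for a single prediction. The helper below, `mask_mae()`, is
a hypothetical NumPy sketch added for illustration, not part of the original example; it assumes
both masks are arrays of the same shape with values in `[0, 1]`.

```python
import numpy as np


def mask_mae(pred_mask, true_mask):
    # Hypothetical helper: mean absolute error between two masks with values in [0, 1].
    pred_mask = np.asarray(pred_mask, dtype=np.float32)
    true_mask = np.asarray(true_mask, dtype=np.float32)
    if pred_mask.shape != true_mask.shape:
        raise ValueError(f"Shape mismatch: {pred_mask.shape} vs {true_mask.shape}")
    # Average the per-pixel absolute differences over every element.
    return float(np.mean(np.abs(pred_mask - true_mask)))


# Toy 2x2 masks with made-up values (not model output).
toy_true = np.array([[0.0, 1.0], [1.0, 0.0]])
toy_pred = np.array([[0.1, 0.9], [0.8, 0.0]])
print(mask_mae(toy_pred, toy_true))
```

In this notebook, a call such as `mask_mae(normalize_output(pred_mask[0][0]), mask[0])` would
score the min-max normalized refined output for the first validation image, since both arrays
have shape `(IMAGE_SIZE, IMAGE_SIZE, 1)`. A score computed on a handful of validation images is
not directly comparable to the paper's dataset-level figure.
"""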