"""
2
Title: Highly accurate boundaries segmentation using BASNet
3
Author: [Hamid Ali](https://github.com/hamidriasat)
4
Date created: 2023/05/30
5
Last modified: 2024/10/02
6
Description: Boundaries aware segmentation model trained on the DUTS dataset.
7
Accelerator: GPU
8
"""

"""
## Introduction

Deep semantic segmentation algorithms have improved a lot recently, but they still fail to
correctly predict pixels around object boundaries. In this example we implement
**Boundary-Aware Segmentation Network (BASNet)**. With its two-stage predict-and-refine
architecture and a hybrid loss, it can predict highly accurate boundaries and fine structures
for image segmentation.

### References:

- [Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704)
- [BASNet Keras Implementation](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
- [Learning to Detect Salient Objects with Image-level Supervision](https://openaccess.thecvf.com/content_cvpr_2017/html/Wang_Learning_to_Detect_CVPR_2017_paper.html)
"""

"""
## Download the Data

We will use the [DUTS-TE](http://saliencydetection.net/duts/) dataset for training. It has 5,019
images, but we will use only 140 of them for training and validation to keep the notebook's
running time short. DUTS is a relatively large salient object segmentation dataset, containing
diversified textures and structures common to real-world images in both foreground and background.
"""

import os

# Because of the use of tf.image.ssim in the loss,
# this example requires TensorFlow. The rest of the code
# is backend-agnostic.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
from glob import glob
import matplotlib.pyplot as plt

import keras_hub
import tensorflow as tf
import keras
from keras import layers, ops

keras.config.disable_traceback_filtering()

"""
## Define Hyperparameters
"""

IMAGE_SIZE = 288
BATCH_SIZE = 4
OUT_CLASSES = 1
TRAIN_SPLIT_RATIO = 0.90

"""
## Create `PyDataset`s

We will use `load_paths()` to load and split the 140 image and mask paths into train and
validation sets, and then convert the paths into `PyDataset` objects.
"""

data_dir = keras.utils.get_file(
    origin="http://saliencydetection.net/duts/download/DUTS-TE.zip",
    extract=True,
)
data_dir = os.path.join(data_dir, "DUTS-TE")
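
# Illustrative sanity check (an addition, not part of the original example):
# the archive extracts to a `DUTS-TE-Image` folder of input photos and a
# `DUTS-TE-Mask` folder of ground-truth masks, which `load_paths()` below reads.
print(f"Images found: {len(glob(os.path.join(data_dir, 'DUTS-TE-Image/*')))}")
print(f"Masks found: {len(glob(os.path.join(data_dir, 'DUTS-TE-Mask/*')))}")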


def load_paths(path, split_ratio):
    images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
    masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:140]
    len_ = int(len(images) * split_ratio)
    return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])


class Dataset(keras.utils.PyDataset):
    def __init__(
        self,
        image_paths,
        mask_paths,
        img_size,
        out_classes,
        batch,
        shuffle=True,
        **kwargs,
    ):
        if shuffle:
            perm = np.random.permutation(len(image_paths))
            image_paths = [image_paths[i] for i in perm]
            mask_paths = [mask_paths[i] for i in perm]
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.img_size = img_size
        self.out_classes = out_classes
        self.batch_size = batch
        super().__init__(**kwargs)

    def __len__(self):
        return len(self.image_paths) // self.batch_size

    def __getitem__(self, idx):
        batch_x, batch_y = [], []
        for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
            x, y = self.preprocess(
                self.image_paths[i],
                self.mask_paths[i],
                self.img_size,
            )
            batch_x.append(x)
            batch_y.append(y)
        batch_x = np.stack(batch_x, axis=0)
        batch_y = np.stack(batch_y, axis=0)
        return batch_x, batch_y

    def read_image(self, path, size, mode):
        x = keras.utils.load_img(path, target_size=size, color_mode=mode)
        x = keras.utils.img_to_array(x)
        x = (x / 255.0).astype(np.float32)
        return x

    def preprocess(self, x_path, y_path, img_size):
        image = self.read_image(x_path, (img_size, img_size), mode="rgb")
        mask = self.read_image(y_path, (img_size, img_size), mode="grayscale")
        return image, mask


train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)

train_dataset = Dataset(
    train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
)
val_dataset = Dataset(
    val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
)
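
# Quick check (an illustrative addition): a 0.90 split of 140 paths gives 126
# train / 14 validation samples, i.e. 31 and 3 full batches of size 4 (any
# remainder is dropped by `__len__`).
print(f"Train batches: {len(train_dataset)}, validation batches: {len(val_dataset)}")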

"""
## Visualize Data
"""


def display(display_list):
    title = ["Input Image", "True Mask", "Predicted Mask"]

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(keras.utils.array_to_img(display_list[i]), cmap="gray")
        plt.axis("off")
    plt.show()


for image, mask in val_dataset:
    display([image[0], mask[0]])
    break

"""
## Analyze Mask

Let's print the unique values of the mask displayed above. Notice that, despite belonging to a
single class, the mask's intensity varies between low (0) and high (255). This variation in
intensity makes it hard for a network to generate a good segmentation map for **salient or
camouflaged object segmentation**. Thanks to its Residual Refinement Module (RRM), BASNet is
good at generating highly accurate boundaries and fine structures.
"""

print(f"Unique values count: {len(np.unique((mask[0] * 255)))}")
print("Unique values:")
print(np.unique((mask[0] * 255)).astype(int))
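
# To see what those intermediate intensities represent, we can binarize the mask
# with a hard 0.5 threshold (a hypothetical post-processing step, not part of
# BASNet): the smooth boundary transition collapses into just two values.
binary_mask = (mask[0] > 0.5).astype(np.float32)
print(f"Unique values after thresholding: {np.unique(binary_mask * 255).astype(int)}")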

"""
## Building the BASNet Model

BASNet comprises a predict-refine architecture and a hybrid loss. The predict-refine
architecture consists of a densely supervised encoder-decoder network and a residual refinement
module, which are respectively used to predict and refine a segmentation probability map.

![](https://i.imgur.com/8jaZ2qs.png)
"""


def basic_block(x_input, filters, stride=1, down_sample=None, activation=None):
    """Creates a residual (identity) block with two 3x3 convolutions."""
    residual = x_input

    x = layers.Conv2D(filters, (3, 3), strides=stride, padding="same", use_bias=False)(
        x_input
    )
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(filters, (3, 3), strides=(1, 1), padding="same", use_bias=False)(
        x
    )
    x = layers.BatchNormalization()(x)

    if down_sample is not None:
        residual = down_sample

    x = layers.Add()([x, residual])

    if activation is not None:
        x = layers.Activation(activation)(x)

    return x


def convolution_block(x_input, filters, dilation=1):
    """Apply convolution + batch normalization + relu layer."""
    x = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)(x_input)
    x = layers.BatchNormalization()(x)
    return layers.Activation("relu")(x)


def segmentation_head(x_input, out_classes, final_size):
    """Map each decoder stage output to model output classes."""
    x = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_input)

    if final_size is not None:
        x = layers.Resizing(final_size[0], final_size[1])(x)

    return x


def get_resnet_block(resnet, block_num):
    """Extract and return a ResNet-34 block."""
    extractor_levels = ["P2", "P3", "P4", "P5"]
    num_blocks = resnet.stackwise_num_blocks
    if block_num == 0:
        x = resnet.get_layer("pool1_pool").output
    else:
        x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
    y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
    return keras.models.Model(
        inputs=x,
        outputs=y,
        name=f"resnet_block{block_num + 1}",
    )


"""
## Prediction Module

The prediction module is a heavy encoder-decoder structure, like U-Net. The encoder includes an
input convolutional layer and six stages; the first four are adopted from ResNet-34 and the rest
are basic res-blocks. Since the first convolution and pooling layers of ResNet-34 are skipped, we
will use `get_resnet_block()` to extract the first four blocks. Both the bridge and the decoder
use three convolutional layers with side outputs. The module produces seven segmentation
probability maps during training, with the last one considered the final output.
"""


def basnet_predict(input_shape, out_classes):
    """BASNet Prediction Module; it outputs a coarse label map."""
    filters = 64
    num_stages = 6

    x_input = layers.Input(input_shape)

    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

    resnet = keras_hub.models.ResNetBackbone(
        input_conv_filters=[64],
        input_conv_kernel_sizes=[7],
        stackwise_num_filters=[64, 128, 256, 512],
        stackwise_num_blocks=[3, 4, 6, 3],
        stackwise_num_strides=[1, 2, 2, 2],
        block_type="basic_block",
    )

    encoder_blocks = []
    for i in range(num_stages):
        if i < 4:  # First four stages are adopted from ResNet-34 blocks.
            x = get_resnet_block(resnet, i)(x)
            encoder_blocks.append(x)
            x = layers.Activation("relu")(x)
        else:  # Last 2 stages consist of three basic resnet blocks.
            x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            x = basic_block(x, filters=filters * 8, activation="relu")
            encoder_blocks.append(x)

    # -------------Bridge-------------
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    x = convolution_block(x, filters=filters * 8, dilation=2)
    encoder_blocks.append(x)

    # -------------Decoder-------------
    decoder_blocks = []
    for i in reversed(range(num_stages)):
        if i != (num_stages - 1):  # Except first, scale other decoder stages.
            shape = x.shape
            x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)

        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        x = convolution_block(x, filters=filters * 8)
        decoder_blocks.append(x)

    decoder_blocks.reverse()  # Change order from last to first decoder stage.
    decoder_blocks.append(encoder_blocks[-1])  # Copy bridge to decoder.

    # -------------Side Outputs--------------
    decoder_blocks = [
        segmentation_head(decoder_block, out_classes, input_shape[:2])
        for decoder_block in decoder_blocks
    ]

    return keras.models.Model(inputs=x_input, outputs=decoder_blocks)


"""
## Residual Refinement Module

The Residual Refinement Module (RRM), designed as a residual block, aims to refine the coarse
(blurry and noisy boundaries) segmentation maps generated by the prediction module. Like the
prediction module, it is an encoder-decoder structure, but a lightweight one with four stages,
each containing one `convolution_block()`. At the end, it adds the coarse and residual outputs
to generate the refined output.
"""


def basnet_rrm(base_model, out_classes):
    """BASNet Residual Refinement Module (RRM); it outputs a fine label map."""
    num_stages = 4
    filters = 64

    x_input = base_model.output[0]

    # -------------Encoder--------------
    x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

    encoder_blocks = []
    for _ in range(num_stages):
        x = convolution_block(x, filters=filters)
        encoder_blocks.append(x)
        x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)

    # -------------Bridge--------------
    x = convolution_block(x, filters=filters)

    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        shape = x.shape
        x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
        x = layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters)

    x = segmentation_head(x, out_classes, None)  # Segmentation head.

    # ------------- refined = coarse + residual
    x = layers.Add()([x_input, x])  # Add prediction + refinement output

    return keras.models.Model(inputs=[base_model.input], outputs=[x])


"""
## Combine Predict and Refinement Module
"""


class BASNet(keras.Model):
    def __init__(self, input_shape, out_classes):
        """BASNet, a combination of two modules:
        the Prediction Module and the Residual Refinement Module (RRM)."""

        # Prediction model.
        predict_model = basnet_predict(input_shape, out_classes)
        # Refinement model.
        refine_model = basnet_rrm(predict_model, out_classes)

        output = refine_model.outputs  # Combine outputs.
        output.extend(predict_model.output)

        # Activations.
        output = [layers.Activation("sigmoid")(x) for x in output]
        super().__init__(inputs=predict_model.input, outputs=output)

        self.smooth = 1.0e-9
        # Binary Cross Entropy loss.
        self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
        # Structural Similarity Index value.
        self.ssim_value = tf.image.ssim
        # Jaccard / IoU loss.
        self.iou_value = self.calculate_iou

    def calculate_iou(
        self,
        y_true,
        y_pred,
    ):
        """Calculate intersection over union (IoU) between images."""
        intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
        union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
        union = union - intersection
        return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)

    def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
        total = 0.0
        for y_pred_i in y_pred:  # y_pred = refine_model.outputs + predict_model.output
            cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)

            ssim_value = self.ssim_value(y_true, y_pred_i, max_val=1)
            ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)

            iou_value = self.iou_value(y_true, y_pred_i)
            iou_loss = 1 - iou_value

            # Add all three losses.
            total += cross_entropy_loss + ssim_loss + iou_loss
        return total


"""
## Hybrid Loss

Another important feature of BASNet is its hybrid loss function, which is a combination of
binary cross entropy, structural similarity and intersection-over-union losses, and guides
the network to learn three-level (i.e., pixel, patch and map level) hierarchical
representations.
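
Concretely, `compute_loss` above applies the hybrid term `bce + (1 - ssim) + (1 - iou)` to each
of the eight output maps and sums the results: cross entropy supervises individual pixels, SSIM
supervises local patches around the boundaries, and IoU supervises the map as a whole.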
"""


basnet_model = BASNet(
    input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
)  # Create model.
basnet_model.summary()  # Show model summary.

optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
# Compile model.
basnet_model.compile(
    optimizer=optimizer,
    metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
)

"""
### Train the Model
"""

basnet_model.fit(train_dataset, validation_data=val_dataset, epochs=1)

"""
### Visualize Predictions

In the paper, BASNet was trained on the DUTS-TR dataset, which has 10,553 images. The model was
trained for 400k iterations with a batch size of eight and without a validation dataset. After
training, the model was evaluated on the DUTS-TE dataset and achieved a mean absolute error of
`0.042`.

Since BASNet is a deep model that cannot be trained in the short amount of time a Keras example
notebook allows, we will load pretrained weights from
[here](https://github.com/hamidriasat/BASNet/tree/basnet_keras) to show model predictions. Due
to compute limitations, these weights were trained for only 120k iterations, but they still
demonstrate the model's capabilities. For further details about the training parameters, please
check the given link.
"""

import gdown

gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")


def normalize_output(prediction):
    max_value = np.max(prediction)
    min_value = np.min(prediction)
    return (prediction - min_value) / (max_value - min_value)


# Load weights.
basnet_model.load_weights("./basnet_weights.h5")

"""
### Make Predictions
"""

for (image, mask), _ in zip(val_dataset, range(1)):
    pred_mask = basnet_model.predict(image)
    display([image[0], mask[0], normalize_output(pred_mask[0][0])])
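
# The sigmoid output is a soft probability map. As a final (hypothetical, not
# from the original example) post-processing step, a hard threshold converts it
# into a binary saliency mask for downstream use.
binary_pred = (normalize_output(pred_mask[0][0]) > 0.5).astype(np.float32)
display([image[0], mask[0], binary_pred])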