"""
Title: Multiclass semantic segmentation using DeepLabV3+
Author: [Soumik Rakshit](http://github.com/soumik12345)
Date created: 2021/08/31
Last modified: 2024/01/05
Description: Implement DeepLabV3+ architecture for Multi-class Semantic Segmentation.
Accelerator: GPU
Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)
"""

"""
## Introduction

Semantic segmentation, whose goal is to assign a semantic label to every pixel in an
image, is an essential computer vision task. In this example, we implement the
**DeepLabV3+** model for multi-class semantic segmentation, a fully-convolutional
architecture that performs well on semantic segmentation benchmarks.

### References:

- [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
- [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
- [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](https://arxiv.org/abs/1606.00915)
"""

"""
## Downloading the data

We will use the [Crowd Instance-level Human Parsing Dataset](https://arxiv.org/abs/1811.12596)
for training our model. The Crowd Instance-level Human Parsing (CIHP) dataset has 38,280 diverse human images.
Each image in CIHP is labeled with pixel-wise annotations for 20 categories, as well as instance-level identification.
This dataset can be used for the "human part segmentation" task.
"""


import keras
from keras import layers
from keras import ops

import os
import numpy as np
from glob import glob
import cv2
from scipy.io import loadmat
import matplotlib.pyplot as plt

# For data preprocessing
from tensorflow import image as tf_image
from tensorflow import data as tf_data
from tensorflow import io as tf_io

"""shell
gdown "1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz&confirm=t"
unzip -q instance-level-human-parsing.zip
"""

"""
## Creating a TensorFlow Dataset

Training on the entire CIHP dataset of 38,280 images would take a long time, so in this
example we will use a smaller subset of 1,000 images for training (and 50 for validation),
matching the constants defined below.
"""

IMAGE_SIZE = 512
BATCH_SIZE = 4
NUM_CLASSES = 20
DATA_DIR = "./instance-level_human_parsing/instance-level_human_parsing/Training"
NUM_TRAIN_IMAGES = 1000
NUM_VAL_IMAGES = 50

train_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[:NUM_TRAIN_IMAGES]
train_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[:NUM_TRAIN_IMAGES]
val_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[
    NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
]
val_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[
    NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
]


def read_image(image_path, mask=False):
    image = tf_io.read_file(image_path)
    if mask:
        # Masks are single-channel PNGs of integer class IDs
        image = tf_image.decode_png(image, channels=1)
        image.set_shape([None, None, 1])
        image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    else:
        image = tf_image.decode_png(image, channels=3)
        image.set_shape([None, None, 3])
        image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    return image


def load_data(image_list, mask_list):
    image = read_image(image_list)
    mask = read_image(mask_list, mask=True)
    return image, mask


def data_generator(image_list, mask_list):
    dataset = tf_data.Dataset.from_tensor_slices((image_list, mask_list))
    dataset = dataset.map(load_data, num_parallel_calls=tf_data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


train_dataset = data_generator(train_images, train_masks)
val_dataset = data_generator(val_images, val_masks)

print("Train Dataset:", train_dataset)
print("Val Dataset:", val_dataset)

"""
## Building the DeepLabV3+ model

DeepLabV3+ extends DeepLabV3 by adding an encoder-decoder structure. The encoder module
processes multiscale contextual information by applying dilated convolution at multiple
scales, while the decoder module refines the segmentation results along object boundaries.

![](https://github.com/lattice-ai/DeepLabV3-Plus/raw/master/assets/deeplabv3_plus_diagram.png)

**Dilated convolution:** Dilated convolution lets us keep the stride constant as we go
deeper in the network while enlarging the field of view, without increasing the number
of parameters or the amount of computation. It also yields larger output feature maps,
which is useful for semantic segmentation.

The reason for using **Dilated Spatial Pyramid Pooling** is that as the sampling rate
becomes larger, the number of valid filter weights (i.e., weights applied to the valid
feature region rather than to padded zeros) becomes smaller, so fusing features pooled
at several dilation rates is more robust than relying on a single large rate.
"""


def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    use_bias=False,
):
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding="same",
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return ops.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    dims = dspp_input.shape
    # Image-level features: global average pooling, 1x1 conv, upsample back
    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
    x = convolution_block(x, kernel_size=1, use_bias=True)
    out_pool = layers.UpSampling2D(
        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
        interpolation="bilinear",
    )(x)

    # Parallel convolutions at increasing dilation rates
    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)

    # Fuse all branches with a final 1x1 convolution
    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
    output = convolution_block(x, kernel_size=1)
    return output


"""
The encoder features are first bilinearly upsampled by a factor of 4, and then
concatenated with the corresponding low-level features from the network backbone that
have the same spatial resolution. For this example, we use a ResNet50 pretrained on
ImageNet as the backbone model, and we use the low-level features from the
`conv4_block6_2_relu` block of the backbone.
"""


def DeeplabV3Plus(image_size, num_classes):
    model_input = keras.Input(shape=(image_size, image_size, 3))
    preprocessed = keras.applications.resnet50.preprocess_input(model_input)
    resnet50 = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=preprocessed
    )
    # Encoder: high-level backbone features refined by the pyramid pooling module
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    # Decoder: upsample encoder output and fuse with low-level backbone features
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # Per-pixel class logits (no activation; we train with from_logits=True)
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    return keras.Model(inputs=model_input, outputs=model_output)


model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
model.summary()

"""
## Training

We train the model using sparse categorical crossentropy as the loss function, and
Adam as the optimizer. Since the final `Conv2D` layer has no activation, the model
outputs raw logits, which is why we set `from_logits=True` in the loss.
"""

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=["accuracy"],
)

history = model.fit(train_dataset, validation_data=val_dataset, epochs=25)

for metric, title in [
    ("loss", "Training Loss"),
    ("accuracy", "Training Accuracy"),
    ("val_loss", "Validation Loss"),
    ("val_accuracy", "Validation Accuracy"),
]:
    plt.plot(history.history[metric])
    plt.title(title)
    plt.ylabel(metric)
    plt.xlabel("epoch")
    plt.show()

"""
## Inference using Colormap Overlay

The raw predictions from the model form a tensor of shape `(N, 512, 512, 20)`, where
each of the 20 channels holds the per-pixel scores (logits) for one label; taking the
channel-wise argmax gives the predicted label at each pixel. To visualize the results,
we plot them as RGB segmentation masks, where each pixel is represented by a unique
color corresponding to the predicted label. The color for each label comes from the
`human_colormap.mat` file provided as part of the dataset. We also plot an overlay of
the RGB segmentation mask on the input image, as this helps us identify the different
categories present in the image more intuitively.
"""

# Loading the Colormap
colormap = loadmat(
    "./instance-level_human_parsing/instance-level_human_parsing/human_colormap.mat"
)["colormap"]
colormap = colormap * 100
colormap = colormap.astype(np.uint8)


def infer(model, image_tensor):
    predictions = model.predict(np.expand_dims((image_tensor), axis=0))
    predictions = np.squeeze(predictions)
    # Channel-wise argmax turns the per-class logits into a label map
    predictions = np.argmax(predictions, axis=2)
    return predictions


def decode_segmentation_masks(mask, colormap, n_classes):
    r = np.zeros_like(mask).astype(np.uint8)
    g = np.zeros_like(mask).astype(np.uint8)
    b = np.zeros_like(mask).astype(np.uint8)
    for label in range(0, n_classes):
        idx = mask == label
        r[idx] = colormap[label, 0]
        g[idx] = colormap[label, 1]
        b[idx] = colormap[label, 2]
    rgb = np.stack([r, g, b], axis=2)
    return rgb


def get_overlay(image, colored_mask):
    image = keras.utils.array_to_img(image)
    image = np.array(image).astype(np.uint8)
    # Blend the input image (35%) with the colored mask (65%)
    overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
    return overlay


def plot_samples_matplotlib(display_list, figsize=(5, 3)):
    _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
    for i in range(len(display_list)):
        if display_list[i].shape[-1] == 3:
            axes[i].imshow(keras.utils.array_to_img(display_list[i]))
        else:
            axes[i].imshow(display_list[i])
    plt.show()


def plot_predictions(images_list, colormap, model):
    for image_file in images_list:
        image_tensor = read_image(image_file)
        prediction_mask = infer(image_tensor=image_tensor, model=model)
        prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
        overlay = get_overlay(image_tensor, prediction_colormap)
        plot_samples_matplotlib(
            [image_tensor, overlay, prediction_colormap], figsize=(18, 14)
        )


"""
### Inference on Train Images
"""

plot_predictions(train_images[:4], colormap, model=model)

"""
### Inference on Validation Images

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deeplabv3p-resnet50)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Human-Part-Segmentation).
"""

plot_predictions(val_images[:4], colormap, model=model)