Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
matterport
GitHub Repository: matterport/Mask_RCNN
Path: blob/master/mrcnn/visualize.py
239 views
1
"""
2
Mask R-CNN
3
Display and Visualization Functions.
4
5
Copyright (c) 2017 Matterport, Inc.
6
Licensed under the MIT License (see LICENSE for details)
7
Written by Waleed Abdulla
8
"""
9
10
import os
11
import sys
12
import random
13
import itertools
14
import colorsys
15
16
import numpy as np
17
from skimage.measure import find_contours
18
import matplotlib.pyplot as plt
19
from matplotlib import patches, lines
20
from matplotlib.patches import Polygon
21
import IPython.display
22
23
# Root directory of the project. Resolved relative to this file (which lives in
# <root>/mrcnn/) instead of the current working directory, so the lookup below
# does not depend on where the interpreter was started.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
28
from mrcnn import utils
29
30
31
############################################################
32
# Visualization
33
############################################################
34
35
def display_images(images, titles=None, cols=4, cmap=None, norm=None,
                   interpolation=None):
    """Display the given set of images in a grid, optionally with titles.

    images: list or array of image tensors in HWC format. Each image is cast
        to uint8 before display.
    titles: optional. A list of titles to display with each image.
    cols: number of images per row
    cmap: Optional. Color map to use. For example, "Blues".
    norm: Optional. A Normalize instance to map values to colors.
    interpolation: Optional. Image interpolation to use for display.
    """
    titles = titles if titles is not None else [""] * len(images)
    # Ceiling division: just enough rows to hold every image, so no empty row
    # is added when the image count is an exact multiple of `cols`. At least
    # one row is kept so the figure size below stays meaningful.
    rows = max(1, -(-len(images) // cols))
    plt.figure(figsize=(14, 14 * rows // cols))
    for i, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(rows, cols, i)
        plt.title(title, fontsize=9)
        plt.axis('off')
        plt.imshow(image.astype(np.uint8), cmap=cmap,
                   norm=norm, interpolation=interpolation)
    plt.show()
57
58
59
def random_colors(N, bright=True):
    """Return N visually distinct RGB colors in random order.

    Hues are spaced evenly around the HSV color wheel at full saturation and
    converted to RGB tuples (components in [0, 1]). The resulting list is
    shuffled in place using the global `random` module state.

    N: number of colors to generate.
    bright: HSV value 1.0 when True, 0.7 (dimmer colors) otherwise.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(k / N, 1, value) for k in range(N)]
    random.shuffle(palette)
    return palette
70
71
72
def apply_mask(image, mask, color, alpha=0.5):
    """Blend `color` into `image` wherever `mask` equals 1, in place.

    image: array indexed as [rows, cols, channel]; only channels 0-2 are
        touched. It is modified in place and also returned. Blended values
        are cast back to the image's dtype on assignment (truncated for
        integer images).
    mask: 2-D array matching the image height/width; pixels equal to 1
        (or True) are tinted, all others are left unchanged.
    color: sequence whose first three components are multiplied by 255 —
        presumably RGB in [0, 1], as produced by random_colors.
    alpha: weight of `color` in the blend; the original pixel gets 1 - alpha.
    """
    selected = (mask == 1)
    for channel in range(3):
        plane = image[:, :, channel]
        tinted = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, tinted, plane)
    return image
81
82
83
def display_instances(image, boxes, masks, class_ids, class_names,
                      scores=None, title="",
                      figsize=(16, 16), ax=None,
                      show_mask=True, show_bbox=True,
                      colors=None, captions=None):
    """Draw instances (boxes, masks, mask outlines, captions) over an image.

    image: [height, width, channels] image. It is copied, not modified.
    boxes: [num_instance, (y1, x1, y2, x2)] in image coordinates. Instances
        whose box is all zeros are skipped.
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) Figure title
    figsize: (optional) the size of the image
    ax: (optional) Matplotlib axis to draw on. When omitted, a new figure is
        created and plt.show() is called before returning.
    show_mask, show_bbox: To show masks and bounding boxes or not. Mask
        outlines are drawn regardless of show_mask.
    colors: (optional) A sequence (list, tuple or array) of colors, one per
        object. Random colors are generated when omitted or empty.
    captions: (optional) A list of strings to use as captions for each object.
        When omitted or empty, each caption is the class name followed by the
        score (the score is shown only when given and non-zero).
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    # If no axis is passed, create one and automatically call show()
    auto_show = ax is None
    if auto_show:
        _, ax = plt.subplots(1, figsize=figsize)

    # Use random colors unless some were supplied. The test is explicit rather
    # than `colors or ...` because the truth value of a multi-element NumPy
    # array is ambiguous and raises ValueError.
    if colors is None or len(colors) == 0:
        colors = random_colors(N)
    # Same reasoning for captions, which may also be passed as an array.
    use_default_captions = captions is None or len(captions) == 0

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    # Draw masks into a uint32 copy; the caller's image is not modified.
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=0.7, linestyle="dashed",
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)

        # Label
        if use_default_captions:
            class_id = class_ids[i]
            score = scores[i] if scores is not None else None
            label = class_names[class_id]
            caption = "{} {:.3f}".format(label, score) if score else label
        else:
            caption = captions[i]
        ax.text(x1, y1 + 8, caption,
                color='w', size=11, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        if show_mask:
            masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros(
            (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
    if auto_show:
        plt.show()
168
169
170
def display_differences(image,
                        gt_box, gt_class_id, gt_mask,
                        pred_box, pred_class_id, pred_score, pred_mask,
                        class_names, title="", ax=None,
                        show_mask=True, show_box=True,
                        iou_threshold=0.5, score_threshold=0.5):
    """Display ground truth and prediction instances on the same image.

    Ground truth instances are drawn in green and predictions in red.
    Predictions are matched to ground truth with utils.compute_matches using
    iou_threshold and score_threshold. Each prediction is captioned
    "<score> / <IoU>". The IoU is taken from its matched ground truth
    instance. If the prediction is unmatched, its highest IoU with any ground
    truth instance is used (0 when there is no ground truth at all). Ground
    truth instances get empty captions.

    The drawing itself is delegated to display_instances. show_box maps to its
    show_bbox flag; ax, show_mask and title are forwarded unchanged (a default
    title is used when title is empty).
    """
    # Match predictions to ground truth
    gt_match, pred_match, overlaps = utils.compute_matches(
        gt_box, gt_class_id, gt_mask,
        pred_box, pred_class_id, pred_score, pred_mask,
        iou_threshold=iou_threshold, score_threshold=score_threshold)
    # Ground truth = green. Predictions = red
    colors = [(0, 1, 0, .8)] * len(gt_match) + [(1, 0, 0, 1)] * len(pred_match)
    # Concatenate GT and predictions
    class_ids = np.concatenate([gt_class_id, pred_class_id])
    scores = np.concatenate([np.zeros([len(gt_match)]), pred_score])
    boxes = np.concatenate([gt_box, pred_box])
    masks = np.concatenate([gt_mask, pred_mask], axis=-1)
    # Captions per instance: empty for ground truth, score/IoU for predictions.
    captions = [""] * len(gt_match)
    for i, match in enumerate(pred_match):
        if match > -1:
            iou = overlaps[i, int(match)]
        elif overlaps[i].size:
            iou = overlaps[i].max()
        else:
            # No ground truth to compare against; max() of an empty row raises.
            iou = 0.0
        captions.append("{:.2f} / {:.2f}".format(pred_score[i], iou))
    # Set title if not provided
    title = title or "Ground Truth and Detections\n GT=green, pred=red, captions: score/IoU"
    # Display
    display_instances(
        image,
        boxes, masks, class_ids,
        class_names, scores, ax=ax,
        show_bbox=show_box, show_mask=show_mask,
        colors=colors, captions=captions,
        title=title)
206
207
208
def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
    """Draw a random sample of ROIs, their refined boxes and their masks.

    image: image array to draw on. A copy is made; the input is not modified.
    rois: [n, (y1, x1, y2, x2)] ROIs in image coordinates. Each is drawn as a
        dashed box, gray for ROIs whose class ID is 0.
    refined_rois: [n, (y1, x1, y2, x2)] the same ROIs refined to fit objects
        better. For ROIs with a non-zero class ID, the refined box is drawn
        solid, with a line from the ROI's top-left corner and a class label.
    mask: per-ROI masks indexed by ROI; each is pasted into the image with
        utils.unmold_mask. (Presumably [n, mask_height, mask_width]; confirm
        against utils.unmold_mask.)
    class_ids: [n] int class ID of each ROI. 0 counts as a negative ROI in the
        printed statistics.
    class_names: list of class names of the dataset.
    limit: maximum number of ROIs to draw. When there are more, a random
        subset of this size is drawn.

    Also prints the number of positive and negative ROIs and the positive
    ratio, computed over all ROIs (not just the drawn sample).
    """
    masked_image = image.copy()

    # Pick random ROIs in case there are too many.
    num_rois = rois.shape[0]
    ids = np.arange(num_rois, dtype=np.int32)
    if num_rois > limit:
        ids = np.random.choice(ids, limit, replace=False)
        title = "Showing {} random ROIs out of {}".format(len(ids), num_rois)
    else:
        title = "{} ROIs".format(len(ids))

    _, ax = plt.subplots(1, figsize=(12, 12))
    ax.set_title(title)

    # Show area outside image boundaries.
    ax.set_ylim(image.shape[0] + 20, -20)
    ax.set_xlim(-50, image.shape[1] + 20)
    ax.axis('off')

    for roi_id in ids:
        color = np.random.rand(3)
        class_id = class_ids[roi_id]
        # ROI
        y1, x1, y2, x2 = rois[roi_id]
        p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                              edgecolor=color if class_id else "gray",
                              facecolor='none', linestyle="dashed")
        ax.add_patch(p)
        # Refined ROI
        if class_id:
            ry1, rx1, ry2, rx2 = refined_rois[roi_id]
            p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1, linewidth=2,
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)
            # Connect the top-left corners of the ROI and its refined box
            ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))

            # Label
            label = class_names[class_id]
            ax.text(rx1, ry1 + 8, "{}".format(label),
                    color='w', size=11, backgroundcolor="none")

            # Mask
            m = utils.unmold_mask(mask[roi_id],
                                  rois[roi_id][:4].astype(np.int32),
                                  image.shape)
            masked_image = apply_mask(masked_image, m, color)

    ax.imshow(masked_image)

    # Print stats
    num_positive = class_ids[class_ids > 0].shape[0]
    num_negative = class_ids[class_ids == 0].shape[0]
    num_total = class_ids.shape[0]
    print("Positive ROIs: ", num_positive)
    print("Negative ROIs: ", num_negative)
    # Guard the ratio against an empty class_ids array (division by zero).
    ratio = num_positive / num_total if num_total else 0.0
    print("Positive Ratio: {:.2f}".format(ratio))
267
268
269
# TODO: Replace with matplotlib equivalent?
270
def draw_box(image, box, color, thickness=2):
    """Draw a bounding box outline directly into the given image array.

    image: array indexed as [rows, cols, ...]; modified in place and returned.
    box: (y1, x1, y2, x2) integer pixel coordinates.
    color: list of 3 int values for RGB, assigned to every outline pixel.
    thickness: line width in pixels (default 2). The top and left edges start
        at y1 / x1. The bottom and right edges start at y2 / x2 and extend
        `thickness` pixels downward / rightward.
    """
    y1, x1, y2, x2 = box
    image[y1:y1 + thickness, x1:x2] = color
    image[y2:y2 + thickness, x1:x2] = color
    image[y1:y2, x1:x1 + thickness] = color
    image[y1:y2, x2:x2 + thickness] = color
    return image
280
281
282
def display_top_masks(image, mask, class_ids, class_names, limit=4):
    """Show an image next to combined instance masks of its largest classes.

    The first panel is the image itself, titled with its height and width.
    Classes are ranked by total mask area (classes with zero area dropped),
    and one panel is added for each of the top `limit` classes. Unused slots
    get a blank panel titled "-". All panels are shown in one row through
    display_images.

    mask: [height, width, num_instances] instance masks.
    class_ids: [num_instances] class ID of each instance.
    """
    panels = [image]
    panel_titles = ["H x W={}x{}".format(image.shape[0], image.shape[1])]

    # Rank classes by total mask area, largest first.
    ranked = []
    for cid in np.unique(class_ids):
        area = np.sum(mask[:, :, np.where(class_ids == cid)[0]])
        ranked.append((cid, area))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    top_ids = [cid for cid, area in ranked if area > 0]

    for slot in range(limit):
        class_id = top_ids[slot] if slot < len(top_ids) else -1
        # Pull masks of instances belonging to the same class; -1 matches
        # nothing, which yields an all-zero panel.
        instances = mask[:, :, np.where(class_ids == class_id)[0]]
        # Weight the k-th instance by k so instances get distinct intensities.
        weights = np.arange(1, instances.shape[-1] + 1)
        panels.append(np.sum(instances * weights, -1))
        panel_titles.append("-" if class_id == -1 else class_names[class_id])
    display_images(panels, titles=panel_titles, cols=limit + 1, cmap="Blues_r")
303
304
305
def plot_precision_recall(AP, precisions, recalls):
    """Plot a precision-recall curve on a new figure, titled with its AP.

    AP: Average precision at IoU >= 0.5, printed in the title.
    precisions: list of precision values (y axis).
    recalls: list of recall values (x axis).

    Both axes span [0, 1.1]. Nothing is returned.
    """
    axes = plt.subplots(1)[1]
    axes.set_title("Precision-Recall Curve. AP@50 = {:.3f}".format(AP))
    for set_limits in (axes.set_ylim, axes.set_xlim):
        set_limits(0, 1.1)
    axes.plot(recalls, precisions)
318
319
320
def plot_overlaps(gt_class_ids, pred_class_ids, pred_scores,
                  overlaps, class_names, threshold=0.5):
    """Draw a grid showing how ground truth objects are classified.

    gt_class_ids: [N] int. Ground truth class IDs. Entries equal to 0 are
        dropped.
    pred_class_ids: [N] int. Predicted class IDs. Entries equal to 0 are
        dropped.
    pred_scores: [N] float. The probability scores of predicted classes,
        shown next to each prediction label.
    overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
        Its rows/columns must line up with the non-zero pred/GT class IDs.
    class_names: list of all class names in the dataset
    threshold: Float. IoU above which a cell is labeled "match" (same class)
        or "wrong" (different class).
    """
    gt_class_ids = gt_class_ids[gt_class_ids != 0]
    pred_class_ids = pred_class_ids[pred_class_ids != 0]

    plt.figure(figsize=(12, 10))
    plt.imshow(overlaps, interpolation='nearest', cmap=plt.cm.Blues)
    plt.yticks(np.arange(len(pred_class_ids)),
               ["{} ({:.2f})".format(class_names[int(class_id)], pred_scores[i])
                for i, class_id in enumerate(pred_class_ids)])
    plt.xticks(np.arange(len(gt_class_ids)),
               [class_names[int(class_id)] for class_id in gt_class_ids],
               rotation=90)

    # Text color: white above half the max overlap, black for other non-zero
    # cells, grey for zero cells.
    thresh = overlaps.max() / 2.
    for i, j in itertools.product(range(overlaps.shape[0]),
                                  range(overlaps.shape[1])):
        text = ""
        if overlaps[i, j] > threshold:
            text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
        color = ("white" if overlaps[i, j] > thresh
                 else "black" if overlaps[i, j] > 0
                 else "grey")
        plt.text(j, i, "{:.3f}\n{}".format(overlaps[i, j], text),
                 horizontalalignment="center", verticalalignment="center",
                 fontsize=9, color=color)

    plt.xlabel("Ground Truth")
    plt.ylabel("Predictions")
    # Lay out last so the axis labels are taken into account and not clipped.
    plt.tight_layout()
357
358
359
def draw_boxes(image, boxes=None, refined_boxes=None,
               masks=None, captions=None, visibilities=None,
               title="", ax=None):
    """Draw bounding boxes and segmentation masks with different
    customizations.

    boxes: [N, (y1, x1, y2, x2)] in image coordinates. Instances whose box is
        all zeros are skipped entirely.
    refined_boxes: Like boxes, but draw with solid lines to show
        that they're the result of refining 'boxes'.
    masks: [height, width, N]
    captions: List of N titles to display on each box. Placed on the refined
        box when refined_boxes is given, otherwise on the box.
    visibilities: (optional) List of values of 0, 1, or 2. Determine how
        prominent each bounding box should be:
        0 = gray dotted box at half opacity, refined box not drawn;
        1 = colored dotted box (the default when visibilities is None);
        2 = colored solid box.
    title: An optional title to show over the image
    ax: (optional) Matplotlib axis to draw on. A new 12x12 figure is created
        when omitted.

    Raises ValueError if a visibility value is not 0, 1 or 2.
    """
    # Named colors such as "gray" must be converted before apply_mask, which
    # indexes and scales the color's components.
    from matplotlib.colors import to_rgb

    # Number of boxes
    assert boxes is not None or refined_boxes is not None
    N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]

    # Matplotlib Axis
    if ax is None:
        _, ax = plt.subplots(1, figsize=(12, 12))

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    margin = image.shape[0] // 10
    ax.set_ylim(image.shape[0] + margin, -margin)
    ax.set_xlim(-margin, image.shape[1] + margin)
    ax.axis('off')

    ax.set_title(title)

    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        # Box visibility
        visibility = visibilities[i] if visibilities is not None else 1
        if visibility == 0:
            color, style, alpha = "gray", "dotted", 0.5
        elif visibility == 1:
            color, style, alpha = colors[i], "dotted", 1
        elif visibility == 2:
            color, style, alpha = colors[i], "solid", 1
        else:
            raise ValueError(
                "visibilities must contain only 0, 1 or 2, got {!r}".format(
                    visibility))

        # Boxes
        if boxes is not None:
            if not np.any(boxes[i]):
                # Skip this instance. Has no bbox. Likely lost in cropping.
                continue
            y1, x1, y2, x2 = boxes[i]
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=alpha, linestyle=style,
                                  edgecolor=color, facecolor='none')
            ax.add_patch(p)

        # Refined boxes. The coordinates are unpacked even when the box is
        # hidden (visibility 0) so the caption below is positioned on this
        # instance, not on coordinates left over from a previous iteration.
        if refined_boxes is not None:
            ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32)
            if visibility > 0:
                p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1,
                                      linewidth=2, edgecolor=color,
                                      facecolor='none')
                ax.add_patch(p)
                # Connect the top-left corners of the anchor and proposal
                if boxes is not None:
                    ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))

        # Captions
        if captions is not None:
            caption = captions[i]
            # If there are refined boxes, display captions on them
            if refined_boxes is not None:
                y1, x1, y2, x2 = ry1, rx1, ry2, rx2
            ax.text(x1, y1, caption, size=11, verticalalignment='top',
                    color='w', backgroundcolor="none",
                    bbox={'facecolor': color, 'alpha': 0.5,
                          'pad': 2, 'edgecolor': 'none'})

        # Masks
        if masks is not None:
            mask = masks[:, :, i]
            masked_image = apply_mask(masked_image, mask, to_rgb(color))
            # Mask Polygon
            # Pad to ensure proper polygons for masks that touch image edges.
            padded_mask = np.zeros(
                (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
            padded_mask[1:-1, 1:-1] = mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8))
459
460
461
def display_table(table):
    """Render rows of values as an HTML table in the IPython output.

    table: an iterable of rows, and each row is an iterable of values. Each
        value is converted with str() and padded to at least 40 characters.
        Cell text is inserted as-is (not HTML-escaped), so cells may carry
        markup.
    """
    body = "".join(
        "<tr>" + "".join("<td>{:40}</td>".format(str(cell)) for cell in row) + "</tr>"
        for row in table
    )
    IPython.display.display(IPython.display.HTML("<table>" + body + "</table>"))
473
474
475
def display_weight_stats(model):
    """Show a table of statistics for every trainable weight of `model`.

    For each layer returned by model.get_trainable_layers() and each of its
    weight arrays, a row with the weight name, shape, min, max and standard
    deviation is rendered via display_table. Nothing is returned.

    Weight names get a red HTML marker when:
      - the array is constant ("dead?"), except the second weight of a
        Conv2D layer (its bias), which is excluded;
      - any value exceeds 1000 in magnitude ("Overflow?").
    """
    rows = [["WEIGHT NAME", "SHAPE", "MIN", "MAX", "STD"]]
    for layer in model.get_trainable_layers():
        arrays = layer.get_weights()  # list of Numpy arrays
        tensors = layer.weights  # list of TF tensors
        is_conv2d = layer.__class__.__name__ == "Conv2D"
        for idx, values in enumerate(arrays):
            lo, hi = values.min(), values.max()
            flags = ""
            if lo == hi and not (is_conv2d and idx == 1):
                flags += "<span style='color:red'>*** dead?</span>"
            if np.abs(lo) > 1000 or np.abs(hi) > 1000:
                flags += "<span style='color:red'>*** Overflow?</span>"
            rows.append([
                tensors[idx].name + flags,
                str(values.shape),
                "{:+9.4f}".format(lo),
                "{:+10.4f}".format(hi),
                "{:+9.4f}".format(values.std()),
            ])
    display_table(rows)
501
502