Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
matterport
GitHub Repository: matterport/Mask_RCNN
Path: blob/master/mrcnn/utils.py
239 views
1
"""
2
Mask R-CNN
3
Common utility functions and classes.
4
5
Copyright (c) 2017 Matterport, Inc.
6
Licensed under the MIT License (see LICENSE for details)
7
Written by Waleed Abdulla
8
"""
9
10
import sys
11
import os
12
import logging
13
import math
14
import random
15
import numpy as np
16
import tensorflow as tf
17
import scipy
18
import skimage.color
19
import skimage.io
20
import skimage.transform
21
import urllib.request
22
import shutil
23
import warnings
24
from distutils.version import LooseVersion
25
26
# URL from which to download the latest COCO trained weights
27
# Release asset (tag v2.0) holding the COCO-trained weights file; this is the
# URL that download_trained_weights() fetches.
COCO_MODEL_URL = "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5"
28
29
30
############################################################
31
# Bounding Boxes
32
############################################################
33
34
def extract_bboxes(mask):
    """Derive one tight bounding box per instance mask.

    mask: [height, width, num_instances]. Non-zero pixels belong to the
        instance.

    Returns: int32 array [num_instances, (y1, x1, y2, x2)]. (y2, x2) lie one
    pixel past the last occupied row/column. An instance with no pixels set
    gets an all-zero box.
    """
    num_instances = mask.shape[-1]
    result = np.zeros([num_instances, 4], dtype=np.int32)
    for idx in range(num_instances):
        instance = mask[:, :, idx]
        # Columns / rows that contain at least one mask pixel.
        occupied_cols = np.where(np.any(instance, axis=0))[0]
        occupied_rows = np.where(np.any(instance, axis=1))[0]
        if occupied_cols.shape[0] == 0:
            # Empty mask (can happen after resizing or cropping): the row
            # keeps its zero initialisation.
            continue
        # The far corner is exclusive, hence the +1.
        result[idx] = np.array([occupied_rows[0], occupied_cols[0],
                                occupied_rows[-1] + 1, occupied_cols[-1] + 1])
    return result
58
59
60
def compute_iou(box, boxes, box_area, boxes_area):
    """Intersection-over-union of one box against an array of boxes.

    box: 1D vector [y1, x1, y2, x2]
    boxes: [boxes_count, (y1, x1, y2, x2)]
    box_area: area of 'box'
    boxes_area: array of length boxes_count with the area of each row of
        'boxes'.

    The areas are supplied by the caller so they can be computed once and
    reused across many calls.

    Returns: array of length boxes_count with the IoU values.
    """
    # Corners of the overlapping rectangle (may be "inverted" when the boxes
    # do not overlap; that is clamped to zero below).
    top = np.maximum(box[0], boxes[:, 0])
    left = np.maximum(box[1], boxes[:, 1])
    bottom = np.minimum(box[2], boxes[:, 2])
    right = np.minimum(box[3], boxes[:, 3])

    overlap_w = np.maximum(right - left, 0)
    overlap_h = np.maximum(bottom - top, 0)
    intersection = overlap_w * overlap_h

    union = box_area + boxes_area[:] - intersection[:]
    return intersection / union
79
80
81
def compute_overlaps(boxes1, boxes2):
    """Build the pairwise IoU matrix between two sets of boxes.

    boxes1, boxes2: [N, (y1, x1, y2, x2)].

    The loop runs over boxes2, so pass the larger set first and the smaller
    one second for better performance.

    Returns: array [boxes1 count, boxes2 count] of IoU values.
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
    # One column per box in boxes2: its IoU against every box in boxes1.
    for column, (candidate, candidate_area) in enumerate(zip(boxes2, area2)):
        overlaps[:, column] = compute_iou(candidate, boxes1,
                                          candidate_area, area1)
    return overlaps
98
99
100
def compute_overlaps_masks(masks1, masks2):
    """Build the pairwise IoU matrix between two sets of masks.

    masks1, masks2: [Height, Width, instances]. Values above 0.5 count as
        mask pixels.

    Returns: array [masks1 instances, masks2 instances] of IoU values. If
    either set has no instances the result is an all-zero array of that shape.
    """
    count1 = masks1.shape[-1]
    count2 = masks2.shape[-1]
    if count1 == 0 or count2 == 0:
        return np.zeros((count1, count2))

    # Binarise and flatten each instance into one column of pixels.
    flat1 = np.reshape(masks1 > .5, (-1, count1)).astype(np.float32)
    flat2 = np.reshape(masks2 > .5, (-1, count2)).astype(np.float32)
    pixels1 = np.sum(flat1, axis=0)
    pixels2 = np.sum(flat2, axis=0)

    # Dot product of 0/1 columns counts the pixels both masks share.
    intersections = np.dot(flat1.T, flat2)
    union = pixels1[:, None] + pixels2[None, :] - intersections
    return intersections / union
120
121
122
def non_max_suppression(boxes, scores, threshold):
    """Greedy non-maximum suppression; returns the indices of the kept boxes.

    boxes: [N, (y1, x1, y2, x2)]. (y2, x2) lies outside the box.
    scores: 1-D array of box scores.
    threshold: Float. Boxes whose IoU with an already kept box exceeds this
        value are discarded.

    Returns: int32 array of indices into 'boxes', highest score first.
    """
    assert boxes.shape[0] > 0
    if boxes.dtype.kind != "f":
        boxes = boxes.astype(np.float32)

    heights = boxes[:, 2] - boxes[:, 0]
    widths = boxes[:, 3] - boxes[:, 1]
    area = heights * widths

    # Candidate indices, best score first.
    remaining = scores.argsort()[::-1]

    kept = []
    while len(remaining) > 0:
        best = remaining[0]
        others = remaining[1:]
        kept.append(best)
        iou = compute_iou(boxes[best], boxes[others], area[best], area[others])
        # Continue only with the candidates that do not overlap the picked
        # box by more than the threshold.
        remaining = others[~(iou > threshold)]
    return np.array(kept, dtype=np.int32)
157
158
159
def apply_box_deltas(boxes, deltas):
    """Shift and rescale boxes by the given refinement deltas.

    boxes: [N, (y1, x1, y2, x2)]. (y2, x2) is outside the box.
    deltas: [N, (dy, dx, log(dh), log(dw))]

    Returns: float32 array [N, (y1, x1, y2, x2)] of the refined boxes.
    """
    boxes = boxes.astype(np.float32)
    dy, dx, log_dh, log_dw = (deltas[:, 0], deltas[:, 1],
                              deltas[:, 2], deltas[:, 3])

    # Corner form -> centre/size form.
    height = boxes[:, 2] - boxes[:, 0]
    width = boxes[:, 3] - boxes[:, 1]
    center_y = boxes[:, 0] + 0.5 * height
    center_x = boxes[:, 1] + 0.5 * width

    # The updates are done in place so the arrays keep their float32 dtype.
    # Centres must move before the sizes are rescaled: the shift is expressed
    # in units of the original height/width.
    center_y += dy * height
    center_x += dx * width
    height *= np.exp(log_dh)
    width *= np.exp(log_dw)

    # Centre/size form -> corner form.
    y1 = center_y - 0.5 * height
    x1 = center_x - 0.5 * width
    return np.stack([y1, x1, y1 + height, x1 + width], axis=1)
181
182
183
def box_refinement_graph(box, gt_box):
    """TensorFlow graph version of the box refinement computation.

    box and gt_box are [N, (y1, x1, y2, x2)]. Both are cast to float32.

    Returns: tensor [N, (dy, dx, log(dh), log(dw))] describing how to move
    'box' onto 'gt_box'.
    """
    def _size_and_center(corners):
        # Corner form -> (height, width, center_y, center_x).
        h = corners[:, 2] - corners[:, 0]
        w = corners[:, 3] - corners[:, 1]
        cy = corners[:, 0] + 0.5 * h
        cx = corners[:, 1] + 0.5 * w
        return h, w, cy, cx

    box = tf.cast(box, tf.float32)
    gt_box = tf.cast(gt_box, tf.float32)

    height, width, center_y, center_x = _size_and_center(box)
    gt_height, gt_width, gt_center_y, gt_center_x = _size_and_center(gt_box)

    # Centre shifts are normalised by the source box size; size changes are
    # expressed as log ratios.
    dy = (gt_center_y - center_y) / height
    dx = (gt_center_x - center_x) / width
    dh = tf.log(gt_height / height)
    dw = tf.log(gt_width / width)

    return tf.stack([dy, dx, dh, dw], axis=1)
207
208
209
def box_refinement(box, gt_box):
    """NumPy version of the box refinement computation.

    box and gt_box are [N, (y1, x1, y2, x2)]. (y2, x2) is assumed to be
    outside the box. Both are converted to float32.

    Returns: float32 array [N, (dy, dx, log(dh), log(dw))] describing how to
    move 'box' onto 'gt_box'.
    """
    def _size_and_center(corners):
        # Corner form -> (height, width, center_y, center_x).
        h = corners[:, 2] - corners[:, 0]
        w = corners[:, 3] - corners[:, 1]
        cy = corners[:, 0] + 0.5 * h
        cx = corners[:, 1] + 0.5 * w
        return h, w, cy, cx

    height, width, center_y, center_x = _size_and_center(
        box.astype(np.float32))
    gt_height, gt_width, gt_center_y, gt_center_x = _size_and_center(
        gt_box.astype(np.float32))

    # Centre shifts are normalised by the source box size; size changes are
    # expressed as log ratios.
    dy = (gt_center_y - center_y) / height
    dx = (gt_center_x - center_x) / width
    dh = np.log(gt_height / height)
    dw = np.log(gt_width / width)

    return np.stack([dy, dx, dh, dw], axis=1)
233
234
235
############################################################
236
# Dataset
237
############################################################
238
239
class Dataset(object):
    """Base class that dataset classes derive from.

    Subclass it and add whatever is specific to your dataset, e.g.:

        class CatsAndDogsDataset(Dataset):
            def load_cats_and_dogs(self):
                ...
            def load_mask(self, image_id):
                ...
            def image_reference(self, image_id):
                ...

    See COCODataset and ShapesDataset as examples.
    """

    def __init__(self, class_map=None):
        # class_map is accepted but not used here.
        # The background class always occupies index 0.
        self.class_info = [{"source": "", "id": 0, "name": "BG"}]
        self.image_info = []
        self._image_ids = []
        self.source_class_ids = {}

    def add_class(self, source, class_id, class_name):
        """Register a class, unless the (source, class_id) pair is already
        registered.
        """
        assert "." not in source, "Source name cannot contain a dot"
        already_registered = any(
            entry["source"] == source and entry["id"] == class_id
            for entry in self.class_info)
        if already_registered:
            return
        self.class_info.append({
            "source": source,
            "id": class_id,
            "name": class_name,
        })

    def add_image(self, source, image_id, path, **kwargs):
        """Register an image. Extra keyword arguments are stored alongside
        the id, source and path in the image's info dict.
        """
        record = {"id": image_id, "source": source, "path": path}
        record.update(kwargs)
        self.image_info.append(record)

    def image_reference(self, image_id):
        """Return a link to the image in its source Website, or other
        details that help to look it up or debug it.

        Override for your dataset, but fall back to this function for
        images that are not in your dataset. This base version returns an
        empty string.
        """
        return ""

    def prepare(self, class_map=None):
        """Build the lookup structures from the registered classes and
        images. Call it after all classes and images have been added.

        TODO: class map is not supported yet. When done, it should handle
        mapping classes from different datasets to the same class ID.
        """

        def clean_name(name):
            """Keep only the part of the name before the first comma."""
            return name.split(",")[0]

        def source_key(info):
            # "<source>.<id>" key used by both source maps.
            return "{}.{}".format(info['source'], info['id'])

        # (Re)build everything from the info dicts.
        self.num_classes = len(self.class_info)
        self.class_ids = np.arange(self.num_classes)
        self.class_names = [clean_name(entry["name"])
                            for entry in self.class_info]
        self.num_images = len(self.image_info)
        self._image_ids = np.arange(self.num_images)

        # "<source>.<id>" -> internal ID, for classes and for images.
        class_map_by_source = {}
        for entry, internal_id in zip(self.class_info, self.class_ids):
            class_map_by_source[source_key(entry)] = internal_id
        self.class_from_source_map = class_map_by_source

        image_map_by_source = {}
        for entry, internal_id in zip(self.image_info, self.image_ids):
            image_map_by_source[source_key(entry)] = internal_id
        self.image_from_source_map = image_map_by_source

        # For every source, the internal class IDs it supports. The
        # background class (index 0) is included for all sources.
        self.sources = list(set(entry['source'] for entry in self.class_info))
        self.source_class_ids = {
            source: [index for index, entry in enumerate(self.class_info)
                     if index == 0 or entry['source'] == source]
            for source in self.sources
        }

    def map_source_class_id(self, source_class_id):
        """Translate a "<source>.<id>" string to the internal int class ID.

        For example:
        dataset.map_source_class_id("coco.12") -> 23
        """
        return self.class_from_source_map[source_class_id]

    def get_source_class_id(self, class_id, source):
        """Translate an internal class ID back to the ID used by 'source'."""
        entry = self.class_info[class_id]
        assert entry['source'] == source
        return entry['id']

    @property
    def image_ids(self):
        return self._image_ids

    def source_image_link(self, image_id):
        """Return the path or URL of the image.

        Override this to return a URL to the image if it's available online
        for easy debugging.
        """
        return self.image_info[image_id]["path"]

    def load_image(self, image_id):
        """Read the image from its registered path and return it as a
        [H,W,3] Numpy array.
        """
        image = skimage.io.imread(self.image_info[image_id]['path'])
        # Anything that is not 3-dimensional is treated as grayscale and
        # expanded to RGB.
        if image.ndim != 3:
            image = skimage.color.gray2rgb(image)
        # Drop the alpha channel if there is one.
        if image.shape[-1] == 4:
            image = image[..., :3]
        return image

    def load_mask(self, image_id):
        """Load the instance masks of the given image.

        Datasets store masks in different ways, so subclasses are expected
        to override this method and return the masks as an array of binary
        masks of shape [height, width, instances].

        Returns:
        masks: A bool array of shape [height, width, instance count] with
            a binary mask per instance.
        class_ids: a 1D array of class IDs of the instance masks.

        This default implementation logs a warning and returns empty arrays.
        """
        logging.warning("You are using the default load_mask(), maybe you need to define your own one.")
        empty_mask = np.empty([0, 0, 0])
        empty_class_ids = np.empty([0], np.int32)
        return empty_mask, empty_class_ids
386
387
388
def resize_image(image, min_dim=None, max_dim=None, min_scale=None, mode="square"):
    """Resize an image without changing its aspect ratio.

    min_dim: if provided, the image is scaled up so that its smaller side
        is at least min_dim (images are never scaled down because of it).
    max_dim: if provided, ensures that the image longest side doesn't
        exceed this value (only applied in "square" mode).
    min_scale: if provided, the scale factor is raised to at least this
        value even if min_dim doesn't require it.
    mode: Resizing mode.
        none: No resizing. Return the image unchanged.
        square: Resize and pad with zeros to get a square image
            of size [max_dim, max_dim].
        pad64: Pads width and height with zeros to make them multiples of 64.
            If min_dim or min_scale are provided, it scales the image up
            before padding. max_dim is ignored in this mode.
            The multiple of 64 is needed to ensure smooth scaling of feature
            maps up and down the 6 levels of the FPN pyramid (2**6=64).
        crop: Picks random crops from the image. First, scales the image
            based on min_dim and min_scale, then picks a random crop of
            size min_dim x min_dim. Can be used in training only.
            max_dim is not used in this mode.

    Returns:
    image: the resized image, in the dtype of the input image
    window: (y1, x1, y2, x2). Position of the actual image content inside
        the returned image, i.e. excluding any padding. The x2, y2 pixels
        are not included.
    scale: The scale factor used to resize the image
    padding: Padding added to the image [(top, bottom), (left, right), (0, 0)]
    crop: (y, x, height, width) of the random crop in "crop" mode,
        None otherwise.
    """
    def _split(total):
        # Distribute a padding amount over both sides; when it is odd the
        # extra pixel goes to the second side.
        first = total // 2
        return first, total - first

    # The result is converted back to the input dtype at the end.
    original_dtype = image.dtype
    h, w = image.shape[:2]
    window = (0, 0, h, w)
    scale = 1
    padding = [(0, 0), (0, 0), (0, 0)]
    crop = None

    if mode == "none":
        return image, window, scale, padding, crop

    # Work out the scale factor: up to min_dim (never down), at least
    # min_scale, and in square mode capped so the long side fits max_dim.
    if min_dim:
        scale = max(1, min_dim / min(h, w))
    if min_scale and scale < min_scale:
        scale = min_scale
    if max_dim and mode == "square":
        longest_side = max(h, w)
        if round(longest_side * scale) > max_dim:
            scale = max_dim / longest_side

    # Resize with bilinear interpolation.
    if scale != 1:
        image = resize(image, (round(h * scale), round(w * scale)),
                       preserve_range=True)

    # Pad or crop, depending on the mode.
    h, w = image.shape[:2]
    if mode == "square":
        rows_pad = _split(max_dim - h)
        cols_pad = _split(max_dim - w)
        padding = [rows_pad, cols_pad, (0, 0)]
        image = np.pad(image, padding, mode='constant', constant_values=0)
        window = (rows_pad[0], cols_pad[0], h + rows_pad[0], w + cols_pad[0])
    elif mode == "pad64":
        assert min_dim % 64 == 0, "Minimum dimension must be a multiple of 64"
        # (-n) % 64 is the distance from n up to the next multiple of 64
        # (zero when n already is one).
        rows_pad = _split((-h) % 64)
        cols_pad = _split((-w) % 64)
        padding = [rows_pad, cols_pad, (0, 0)]
        image = np.pad(image, padding, mode='constant', constant_values=0)
        window = (rows_pad[0], cols_pad[0], h + rows_pad[0], w + cols_pad[0])
    elif mode == "crop":
        # Random top-left corner; y is drawn before x.
        y = random.randint(0, (h - min_dim))
        x = random.randint(0, (w - min_dim))
        crop = (y, x, min_dim, min_dim)
        image = image[y:y + min_dim, x:x + min_dim]
        window = (0, 0, min_dim, min_dim)
    else:
        raise Exception("Mode {} not supported".format(mode))
    return image.astype(original_dtype), window, scale, padding, crop
493
494
495
def resize_mask(mask, scale, padding, crop=None):
    """Scale a mask and then pad or crop it.

    Typically, scale, padding and crop come from resize_image() so that
    the image and its mask are transformed consistently.

    scale: mask scaling factor, applied to the first two axes
    padding: Padding to add to the mask in the form
        [(top, bottom), (left, right), (0, 0)]. Only used when crop is None.
    crop: optional (y, x, height, width). When given, the scaled mask is
        cropped to this region instead of being padded.
    """
    # Suppress warning from scipy 0.13.0, the output shape of zoom() is
    # calculated with round() instead of int()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # order=0 is nearest-neighbour, which keeps the mask values intact.
        scaled = scipy.ndimage.zoom(mask, zoom=[scale, scale, 1], order=0)

    if crop is None:
        return np.pad(scaled, padding, mode='constant', constant_values=0)

    top, left, crop_h, crop_w = crop
    return scaled[top:top + crop_h, left:left + crop_w]
515
516
517
def minimize_mask(bbox, mask, mini_shape):
    """Resize masks to a smaller version to reduce memory load.

    Each instance mask is cropped to its bounding box and the crop is
    resized to mini_shape. Mini-masks can be resized back to image scale
    using expand_mask().

    bbox: [num_instances, (y1, x1, y2, x2, ...)]. Only the first four values
        of each row are used, as slice bounds into the mask.
    mask: [height, width, num_instances].
    mini_shape: (height, width) of the mini-masks. Any sequence is accepted.

    Returns: bool array [mini_height, mini_width, num_instances].

    Raises: Exception if a bounding box selects an empty region of the mask.

    See inspect_data.ipynb notebook for more details.
    """
    mini_shape = tuple(mini_shape)
    mini_mask = np.zeros(mini_shape + (mask.shape[-1],), dtype=bool)
    for i in range(mask.shape[-1]):
        # Pick slice and cast to bool in case load_mask() returned wrong dtype
        m = mask[:, :, i].astype(bool)
        y1, x1, y2, x2 = bbox[i][:4]
        m = m[y1:y2, x1:x2]
        if m.size == 0:
            raise Exception("Invalid bounding box with area of zero")
        # Resize with bilinear interpolation
        m = resize(m, mini_shape)
        # Use the builtin bool: the np.bool alias no longer exists in
        # current NumPy releases.
        mini_mask[:, :, i] = np.around(m).astype(bool)
    return mini_mask
535
536
537
def expand_mask(bbox, mini_mask, image_shape):
    """Resizes mini masks back to image size. Reverses the change
    of minimize_mask().

    bbox: [num_instances, (y1, x1, y2, x2, ...)]. Only the first four values
        of each row are used; they give the region of the image that the
        instance's mini-mask is resized into.
    mini_mask: [mini_height, mini_width, num_instances].
    image_shape: shape of the target image; only the first two entries
        (height, width) are used. Any sequence is accepted.

    Returns: bool array [height, width, num_instances].

    See inspect_data.ipynb notebook for more details.
    """
    mask = np.zeros(tuple(image_shape[:2]) + (mini_mask.shape[-1],),
                    dtype=bool)
    for i in range(mask.shape[-1]):
        m = mini_mask[:, :, i]
        y1, x1, y2, x2 = bbox[i][:4]
        h = y2 - y1
        w = x2 - x1
        # Resize with bilinear interpolation
        m = resize(m, (h, w))
        # Use the builtin bool: the np.bool alias no longer exists in
        # current NumPy releases.
        mask[y1:y2, x1:x2, i] = np.around(m).astype(bool)
    return mask
553
554
555
# TODO: Build and use this function to reduce code duplication
556
def mold_mask(mask, config):
    """Placeholder that is not implemented yet; it ignores its arguments
    and returns None.
    """
    return None
558
559
560
def unmold_mask(mask, bbox, image_shape):
    """Converts a mask generated by the neural network to a format similar
    to its original shape.

    mask: [height, width] of type float. A small, typically 28x28 mask.
    bbox: [y1, x1, y2, x2]. The box to fit the mask in.
    image_shape: shape of the original image; only the first two entries
        (height, width) are used.

    Returns a binary (bool) mask with the same size as the original image.
    """
    threshold = 0.5
    y1, x1, y2, x2 = bbox
    # Stretch the small mask to the size of its box, then binarise it.
    mask = resize(mask, (y2 - y1, x2 - x1))
    # Use the builtin bool: the np.bool alias no longer exists in current
    # NumPy releases.
    mask = np.where(mask >= threshold, 1, 0).astype(bool)

    # Put the mask in the right location.
    full_mask = np.zeros(image_shape[:2], dtype=bool)
    full_mask[y1:y2, x1:x2] = mask
    return full_mask
577
578
579
############################################################
580
# Anchors
581
############################################################
582
583
def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
    """Generate anchor boxes for every sampled position of a feature map.

    scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
    ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
    shape: [height, width] spatial shape of the feature map over which
        to generate anchors.
    feature_stride: Stride of the feature map relative to the image in pixels.
    anchor_stride: Stride of anchors on the feature map. For example, if the
        value is 2 then generate anchors for every other feature map pixel.

    Returns: array [N, (y1, x1, y2, x2)] of anchor corner coordinates.
    """
    # Every (scale, ratio) combination.
    scale_grid, ratio_grid = np.meshgrid(np.array(scales), np.array(ratios))
    all_scales = scale_grid.flatten()
    all_ratios = ratio_grid.flatten()

    # ratio = width / height, so the ratio's square root is split between
    # the two sides.
    heights = all_scales / np.sqrt(all_ratios)
    widths = all_scales * np.sqrt(all_ratios)

    # Anchor centre positions, converted from feature map cells to pixels.
    rows = np.arange(0, shape[0], anchor_stride) * feature_stride
    cols = np.arange(0, shape[1], anchor_stride) * feature_stride
    grid_x, grid_y = np.meshgrid(cols, rows)

    # Pair every position with every anchor size.
    box_widths, box_centers_x = np.meshgrid(widths, grid_x)
    box_heights, box_centers_y = np.meshgrid(heights, grid_y)

    # Flatten into a list of (y, x) centres and a list of (h, w) sizes.
    centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
    sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])

    # Centre/size form -> corner form (y1, x1, y2, x2).
    half_sizes = 0.5 * sizes
    return np.concatenate([centers - half_sizes, centers + half_sizes], axis=1)
620
621
622
def generate_pyramid_anchors(scales, ratios, feature_shapes, feature_strides,
                             anchor_stride):
    """Generate the anchors of all levels of a feature pyramid.

    Level i uses scales[i], feature_shapes[i] and feature_strides[i]; the
    same ratios and anchor_stride are used on every level.

    Returns:
    anchors: [N, (y1, x1, y2, x2)]. All generated anchors in one array, in
        the order of the given scales: anchors of scales[0] come first,
        then anchors of scales[1], and so on.
    """
    per_level = [
        generate_anchors(scales[level], ratios, feature_shapes[level],
                         feature_strides[level], anchor_stride)
        for level in range(len(scales))
    ]
    return np.concatenate(per_level, axis=0)
640
641
642
############################################################
643
# Miscellaneous
644
############################################################
645
646
def trim_zeros(x):
    """Drop the rows of a 2-D array that consist of zeros only.

    Useful for arrays that were zero-padded to a fixed size that is larger
    than the available data.

    x: [rows, columns].
    """
    assert len(x.shape) == 2
    # A row is kept as soon as one of its entries differs from zero.
    has_data = np.any(x != 0, axis=1)
    return x[has_data]
654
655
656
def compute_matches(gt_boxes, gt_class_ids, gt_masks,
                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
                    iou_threshold=0.5, score_threshold=0.0):
    """Match predicted instances to ground truth instances by mask IoU.

    Predictions are processed from the highest to the lowest score. Each one
    is matched to the not yet matched ground truth instance with the highest
    IoU, provided that IoU reaches iou_threshold and the class IDs agree.

    Returns:
    gt_match: 1-D array. For each GT box it has the index of the matched
        predicted box (in score-sorted order), or -1 if unmatched.
    pred_match: 1-D array. For each predicted box (in score-sorted order),
        it has the index of the matched ground truth box, or -1 if unmatched.
    overlaps: [pred_boxes, gt_boxes] IoU overlaps.
    """
    # Remove zero padding.
    # TODO: cleaner to do zero unpadding upstream
    gt_boxes = trim_zeros(gt_boxes)
    gt_masks = gt_masks[..., :gt_boxes.shape[0]]
    pred_boxes = trim_zeros(pred_boxes)
    pred_scores = pred_scores[:pred_boxes.shape[0]]

    # Reorder the predictions, best score first.
    ranking = np.argsort(pred_scores)[::-1]
    pred_boxes = pred_boxes[ranking]
    pred_class_ids = pred_class_ids[ranking]
    pred_scores = pred_scores[ranking]
    pred_masks = pred_masks[..., ranking]

    # IoU matrix [pred_masks, gt_masks].
    overlaps = compute_overlaps_masks(pred_masks, gt_masks)

    pred_match = -1 * np.ones([pred_boxes.shape[0]])
    gt_match = -1 * np.ones([gt_boxes.shape[0]])
    for pred_idx in range(len(pred_boxes)):
        pred_overlaps = overlaps[pred_idx]
        # Ground truth candidates, highest IoU first, cut off at the first
        # one whose IoU is below score_threshold.
        candidates = np.argsort(pred_overlaps)[::-1]
        too_low = np.where(pred_overlaps[candidates] < score_threshold)[0]
        if too_low.size > 0:
            candidates = candidates[:too_low[0]]

        for gt_idx in candidates:
            # Skip ground truth instances that already have a match.
            if gt_match[gt_idx] > -1:
                continue
            # Candidates are sorted, so nothing after this one can reach
            # the IoU threshold either.
            if pred_overlaps[gt_idx] < iou_threshold:
                break
            if pred_class_ids[pred_idx] == gt_class_ids[gt_idx]:
                gt_match[gt_idx] = pred_idx
                pred_match[pred_idx] = gt_idx
                break

    return gt_match, pred_match, overlaps
713
714
715
def compute_ap(gt_boxes, gt_class_ids, gt_masks,
               pred_boxes, pred_class_ids, pred_scores, pred_masks,
               iou_threshold=0.5):
    """Compute Average Precision at a set IoU threshold (default 0.5).

    Returns:
    mAP: Mean Average Precision
    precisions: List of precisions at different class score thresholds.
    recalls: List of recall values at different class score thresholds.
    overlaps: [pred_boxes, gt_boxes] IoU overlaps.
    """
    gt_match, pred_match, overlaps = compute_matches(
        gt_boxes, gt_class_ids, gt_masks,
        pred_boxes, pred_class_ids, pred_scores, pred_masks,
        iou_threshold)

    # Running number of matched predictions after each prediction.
    matched_so_far = np.cumsum(pred_match > -1)
    precisions = matched_so_far / (np.arange(len(pred_match)) + 1)
    recalls = matched_so_far.astype(np.float32) / len(gt_match)

    # Sentinel values at both ends simplify the math below.
    precisions = np.concatenate([[0], precisions, [0]])
    recalls = np.concatenate([[0], recalls, [1]])

    # Make the precision curve non-increasing: every entry becomes the
    # maximum of itself and all entries after it, as specified by the VOC
    # paper. A running maximum over the reversed array does exactly that.
    precisions = np.maximum.accumulate(precisions[::-1])[::-1].copy()

    # Sum precision over the steps where recall changes.
    steps = np.where(recalls[:-1] != recalls[1:])[0] + 1
    mAP = np.sum((recalls[steps] - recalls[steps - 1]) * precisions[steps])

    return mAP, precisions, recalls, overlaps
752
753
754
def compute_ap_range(gt_box, gt_class_id, gt_mask,
                     pred_box, pred_class_id, pred_score, pred_mask,
                     iou_thresholds=None, verbose=1):
    """Compute AP over a range of IoU thresholds. Default range is 0.5-0.95.

    iou_thresholds: optional sequence (list, tuple or NumPy array) of IoU
        thresholds. When None or empty, 0.5 to 0.95 in steps of 0.05 is used.
    verbose: if truthy, prints the AP at every threshold and the mean.

    Returns: the mean of the AP values over all thresholds.
    """
    # Test for None/empty explicitly: applying `or` to a NumPy array with
    # more than one element raises ValueError (ambiguous truth value).
    if iou_thresholds is None or len(iou_thresholds) == 0:
        # Default is 0.5 to 0.95 with increments of 0.05
        iou_thresholds = np.arange(0.5, 1.0, 0.05)

    # Compute AP over range of IoU thresholds
    AP = []
    for iou_threshold in iou_thresholds:
        # Only the AP value of compute_ap()'s result is needed here.
        ap = compute_ap(gt_box, gt_class_id, gt_mask,
                        pred_box, pred_class_id, pred_score, pred_mask,
                        iou_threshold=iou_threshold)[0]
        if verbose:
            print("AP @{:.2f}:\t {:.3f}".format(iou_threshold, ap))
        AP.append(ap)
    AP = np.array(AP).mean()
    if verbose:
        print("AP @{:.2f}-{:.2f}:\t {:.3f}".format(
            iou_thresholds[0], iou_thresholds[-1], AP))
    return AP
776
777
778
def compute_recall(pred_boxes, gt_boxes, iou):
    """Compute the recall at the given IoU threshold. It's an indication
    of how many GT boxes were found by the given prediction boxes.

    pred_boxes: [N, (y1, x1, y2, x2)] in image coordinates
    gt_boxes: [N, (y1, x1, y2, x2)] in image coordinates

    Returns:
    recall: fraction of GT boxes that are the best match of at least one
        prediction with an IoU of at least 'iou'.
    positive_ids: indices of the predictions whose best IoU reaches 'iou'.
    """
    overlaps = compute_overlaps(pred_boxes, gt_boxes)
    # Best ground truth box (and its IoU) for every prediction.
    best_iou = np.max(overlaps, axis=1)
    best_gt = np.argmax(overlaps, axis=1)
    positive_ids = np.where(best_iou >= iou)[0]

    # Several predictions may hit the same GT box; count each one once.
    found_gt = set(best_gt[positive_ids])
    recall = len(found_gt) / gt_boxes.shape[0]
    return recall, positive_ids
794
795
796
# ## Batch Slicing
797
# Some custom layers support a batch size of 1 only, and require a lot of work
798
# to support batches greater than 1. This function slices an input tensor
799
# across the batch dimension and feeds batches of size 1. Effectively,
800
# an easy way to support batches > 1 quickly with little code modification.
801
# In the long run, it's more efficient to modify the code to support large
802
# batches and getting rid of this function. Consider this a temporary solution
803
def batch_slice(inputs, graph_fn, batch_size, names=None):
    """Run a single-instance graph function on every item of a batch and
    stack the results.

    The inputs are sliced along their first dimension, each slice is fed to
    graph_fn, and the per-slice outputs are combined again. This makes it
    possible to run a graph on a batch of inputs even if the graph is written
    to support one instance only.

    inputs: list of tensors. All must have the same first dimension length.
        A single non-list input is wrapped in a list.
    graph_fn: A function that returns a TF tensor that's part of a graph
        (or a tuple/list of them).
    batch_size: number of slices to divide the data into.
    names: If provided, assigns names to the resulting tensors.

    Returns: a list with one stacked tensor per output of graph_fn, or the
    tensor itself if there is exactly one output.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]

    per_slice = []
    for index in range(batch_size):
        produced = graph_fn(*[tensor[index] for tensor in inputs])
        if not isinstance(produced, (tuple, list)):
            produced = [produced]
        per_slice.append(produced)

    # Transpose: from one list of outputs per slice to one list of slices
    # per output.
    per_output = list(zip(*per_slice))

    if names is None:
        names = [None] * len(per_output)

    stacked = [tf.stack(slices, axis=0, name=label)
               for slices, label in zip(per_output, names)]
    return stacked[0] if len(stacked) == 1 else stacked
838
839
840
def download_trained_weights(coco_model_path, verbose=1):
    """Download the COCO trained weights from COCO_MODEL_URL.

    coco_model_path: local path the weights file is written to
    verbose: if greater than 0, progress messages are printed
    """
    show_progress = verbose > 0
    if show_progress:
        print("Downloading pretrained model to " + coco_model_path + " ...")
    # Stream the response straight into the target file.
    with urllib.request.urlopen(COCO_MODEL_URL) as resp, open(coco_model_path, 'wb') as out:
        shutil.copyfileobj(resp, out)
    if show_progress:
        print("... done downloading pretrained model!")
851
852
853
def norm_boxes(boxes, shape):
    """Convert boxes from pixel coordinates to normalized coordinates.

    boxes: [N, (y1, x1, y2, x2)] in pixel coordinates
    shape: (height, width) in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        float32 array [N, (y1, x1, y2, x2)] in normalized coordinates
    """
    height, width = shape
    last_row, last_col = height - 1, width - 1
    scale = np.array([last_row, last_col, last_row, last_col])
    # Pull (y2, x2) back by one pixel so it lies inside the box.
    shift = np.array([0, 0, 1, 1])
    normalized = (boxes - shift) / scale
    return normalized.astype(np.float32)
868
869
870
def denorm_boxes(boxes, shape):
    """Convert boxes from normalized coordinates to pixel coordinates.

    boxes: [N, (y1, x1, y2, x2)] in normalized coordinates
    shape: (height, width) in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        int32 array [N, (y1, x1, y2, x2)] in pixel coordinates
    """
    height, width = shape
    last_row, last_col = height - 1, width - 1
    scale = np.array([last_row, last_col, last_row, last_col])
    # Push (y2, x2) out by one pixel so it lies outside the box again.
    shift = np.array([0, 0, 1, 1])
    pixels = np.around(boxes * scale + shift)
    return pixels.astype(np.int32)
885
886
887
def resize(image, output_shape, order=1, mode='constant', cval=0, clip=True,
           preserve_range=False, anti_aliasing=False, anti_aliasing_sigma=None):
    """A wrapper for Scikit-Image resize().

    Scikit-Image generates warnings on every call to resize() if it doesn't
    receive the right parameters. The right parameters depend on the version
    of skimage. This wrapper passes the set of parameters that fits the
    installed version, and it provides a central place to control resizing
    defaults.
    """
    options = dict(order=order, mode=mode, cval=cval, clip=clip,
                   preserve_range=preserve_range)
    if LooseVersion(skimage.__version__) >= LooseVersion("0.14"):
        # New in 0.14: anti_aliasing. It defaults to False here for backward
        # compatibility with skimage 0.13. Older versions are not given
        # these arguments at all.
        options.update(anti_aliasing=anti_aliasing,
                       anti_aliasing_sigma=anti_aliasing_sigma)
    return skimage.transform.resize(image, output_shape, **options)
909
910