GitHub Repository: matterport/Mask_RCNN
Path: blob/master/mrcnn/model.py
1
"""
2
Mask R-CNN
3
The main Mask R-CNN model implementation.
4
5
Copyright (c) 2017 Matterport, Inc.
6
Licensed under the MIT License (see LICENSE for details)
7
Written by Waleed Abdulla
8
"""
9
10
import os
11
import random
12
import datetime
13
import re
14
import math
15
import logging
16
from collections import OrderedDict
17
import multiprocessing
18
import numpy as np
19
import tensorflow as tf
20
import keras
21
import keras.backend as K
22
import keras.layers as KL
23
import keras.engine as KE
24
import keras.models as KM
25
26
from mrcnn import utils
27
28
# Requires TensorFlow 1.3+ and Keras 2.0.8+.
29
from distutils.version import LooseVersion
30
assert LooseVersion(tf.__version__) >= LooseVersion("1.3")
31
assert LooseVersion(keras.__version__) >= LooseVersion('2.0.8')
32
33
34
############################################################
35
# Utility Functions
36
############################################################
37
38
def log(text, array=None):
39
"""Prints a text message. And, optionally, if a Numpy array is provided it
40
prints its shape, min, and max values.
41
"""
42
if array is not None:
43
text = text.ljust(25)
44
text += ("shape: {:20} ".format(str(array.shape)))
45
if array.size:
46
text += ("min: {:10.5f} max: {:10.5f}".format(array.min(),array.max()))
47
else:
48
text += ("min: {:10} max: {:10}".format("",""))
49
text += " {}".format(array.dtype)
50
print(text)
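
# Illustrative usage sketch (not part of the original file): log() is handy for
# tracing array statistics while debugging the data pipeline, e.g.
#
#   image = np.zeros([1024, 1024, 3], dtype=np.float32)
#   log("image", image)
#   # -> image   shape: (1024, 1024, 3)   min: 0.00000  max: 0.00000  float32
#
# (exact column widths depend on the format strings above).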
51
52
53
class BatchNorm(KL.BatchNormalization):
54
"""Extends the Keras BatchNormalization class to allow a central place
55
to make changes if needed.
56
57
Batch normalization has a negative effect on training if batches are small
58
so this layer is often frozen (via setting in Config class) and functions
59
as a linear layer.
60
"""
61
def call(self, inputs, training=None):
62
"""
63
Note about training values:
64
None: Train BN layers. This is the normal mode
65
False: Freeze BN layers. Good when batch size is small
66
True: (don't use). Set layer in training mode even when making inferences
67
"""
68
return super(self.__class__, self).call(inputs, training=training)
69
70
71
def compute_backbone_shapes(config, image_shape):
72
"""Computes the width and height of each stage of the backbone network.
73
74
Returns:
75
[N, (height, width)], where N is the number of stages.
76
"""
77
if callable(config.BACKBONE):
78
return config.COMPUTE_BACKBONE_SHAPE(image_shape)
79
80
# Currently supports ResNet only
81
assert config.BACKBONE in ["resnet50", "resnet101"]
82
return np.array(
83
[[int(math.ceil(image_shape[0] / stride)),
84
int(math.ceil(image_shape[1] / stride))]
85
for stride in config.BACKBONE_STRIDES])
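
# Hypothetical worked example (an illustration, not part of the original file),
# assuming the default BACKBONE_STRIDES of [4, 8, 16, 32, 64]: a 1024x1024
# input yields stage shapes [[256, 256], [128, 128], [64, 64], [32, 32],
# [16, 16]], i.e. ceil(1024 / stride) per side for each pyramid stage.
def _example_backbone_shapes(image_shape=(1024, 1024, 3),
                             strides=(4, 8, 16, 32, 64)):
    """Sketch of the same ceil-division performed by compute_backbone_shapes()."""
    return np.array([[int(math.ceil(image_shape[0] / s)),
                      int(math.ceil(image_shape[1] / s))]
                     for s in strides])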
86
87
88
############################################################
89
# Resnet Graph
90
############################################################
91
92
# Code adapted from:
93
# https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py
94
95
def identity_block(input_tensor, kernel_size, filters, stage, block,
96
use_bias=True, train_bn=True):
97
"""The identity_block is the block that has no conv layer at shortcut
98
# Arguments
99
input_tensor: input tensor
100
kernel_size: default 3, the kernel size of middle conv layer at main path
101
filters: list of integers, the nb_filters of the 3 conv layers at main path
102
stage: integer, current stage label, used for generating layer names
103
block: 'a','b'..., current block label, used for generating layer names
104
use_bias: Boolean. To use or not use a bias in conv layers.
105
train_bn: Boolean. Train or freeze Batch Norm layers
106
"""
107
nb_filter1, nb_filter2, nb_filter3 = filters
108
conv_name_base = 'res' + str(stage) + block + '_branch'
109
bn_name_base = 'bn' + str(stage) + block + '_branch'
110
111
x = KL.Conv2D(nb_filter1, (1, 1), name=conv_name_base + '2a',
112
use_bias=use_bias)(input_tensor)
113
x = BatchNorm(name=bn_name_base + '2a')(x, training=train_bn)
114
x = KL.Activation('relu')(x)
115
116
x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
117
name=conv_name_base + '2b', use_bias=use_bias)(x)
118
x = BatchNorm(name=bn_name_base + '2b')(x, training=train_bn)
119
x = KL.Activation('relu')(x)
120
121
x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c',
122
use_bias=use_bias)(x)
123
x = BatchNorm(name=bn_name_base + '2c')(x, training=train_bn)
124
125
x = KL.Add()([x, input_tensor])
126
x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
127
return x
128
129
130
def conv_block(input_tensor, kernel_size, filters, stage, block,
131
strides=(2, 2), use_bias=True, train_bn=True):
132
"""conv_block is the block that has a conv layer at shortcut
133
# Arguments
134
input_tensor: input tensor
135
kernel_size: default 3, the kernel size of middle conv layer at main path
136
filters: list of integers, the nb_filters of the 3 conv layers at main path
137
stage: integer, current stage label, used for generating layer names
138
block: 'a','b'..., current block label, used for generating layer names
139
use_bias: Boolean. To use or not use a bias in conv layers.
140
train_bn: Boolean. Train or freeze Batch Norm layers
141
Note that from stage 3, the first conv layer at main path uses strides=(2, 2),
142
and the shortcut should use strides=(2, 2) as well.
143
"""
144
nb_filter1, nb_filter2, nb_filter3 = filters
145
conv_name_base = 'res' + str(stage) + block + '_branch'
146
bn_name_base = 'bn' + str(stage) + block + '_branch'
147
148
x = KL.Conv2D(nb_filter1, (1, 1), strides=strides,
149
name=conv_name_base + '2a', use_bias=use_bias)(input_tensor)
150
x = BatchNorm(name=bn_name_base + '2a')(x, training=train_bn)
151
x = KL.Activation('relu')(x)
152
153
x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
154
name=conv_name_base + '2b', use_bias=use_bias)(x)
155
x = BatchNorm(name=bn_name_base + '2b')(x, training=train_bn)
156
x = KL.Activation('relu')(x)
157
158
x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base +
159
'2c', use_bias=use_bias)(x)
160
x = BatchNorm(name=bn_name_base + '2c')(x, training=train_bn)
161
162
shortcut = KL.Conv2D(nb_filter3, (1, 1), strides=strides,
163
name=conv_name_base + '1', use_bias=use_bias)(input_tensor)
164
shortcut = BatchNorm(name=bn_name_base + '1')(shortcut, training=train_bn)
165
166
x = KL.Add()([x, shortcut])
167
x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
168
return x
169
170
171
def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
172
"""Build a ResNet graph.
173
architecture: Can be resnet50 or resnet101
174
stage5: Boolean. If False, stage5 of the network is not created
175
train_bn: Boolean. Train or freeze Batch Norm layers
176
"""
177
assert architecture in ["resnet50", "resnet101"]
178
# Stage 1
179
x = KL.ZeroPadding2D((3, 3))(input_image)
180
x = KL.Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=True)(x)
181
x = BatchNorm(name='bn_conv1')(x, training=train_bn)
182
x = KL.Activation('relu')(x)
183
C1 = x = KL.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)
184
# Stage 2
185
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), train_bn=train_bn)
186
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', train_bn=train_bn)
187
C2 = x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', train_bn=train_bn)
188
# Stage 3
189
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', train_bn=train_bn)
190
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', train_bn=train_bn)
191
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', train_bn=train_bn)
192
C3 = x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', train_bn=train_bn)
193
# Stage 4
194
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', train_bn=train_bn)
195
block_count = {"resnet50": 5, "resnet101": 22}[architecture]
196
for i in range(block_count):
197
x = identity_block(x, 3, [256, 256, 1024], stage=4, block=chr(98 + i), train_bn=train_bn)
198
C4 = x
199
# Stage 5
200
if stage5:
201
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', train_bn=train_bn)
202
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', train_bn=train_bn)
203
C5 = x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', train_bn=train_bn)
204
else:
205
C5 = None
206
return [C1, C2, C3, C4, C5]
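
# Illustrative usage sketch (not part of the original file): building the
# backbone on a fixed-size input and exposing the stage outputs as a model.
# For a 1024x1024x3 input the stages have spatial sizes C1/C2: 256x256,
# C3: 128x128, C4: 64x64, C5: 32x32 (strides 4, 4, 8, 16, 32 vs. the input).
def _example_resnet_backbone():
    input_image = KL.Input(shape=[1024, 1024, 3], name="example_input_image")
    C1, C2, C3, C4, C5 = resnet_graph(input_image, "resnet101", stage5=True)
    return KM.Model(input_image, [C1, C2, C3, C4, C5], name="example_backbone")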
207
208
209
############################################################
210
# Proposal Layer
211
############################################################
212
213
def apply_box_deltas_graph(boxes, deltas):
214
"""Applies the given deltas to the given boxes.
215
boxes: [N, (y1, x1, y2, x2)] boxes to update
216
deltas: [N, (dy, dx, log(dh), log(dw))] refinements to apply
217
"""
218
# Convert to y, x, h, w
219
height = boxes[:, 2] - boxes[:, 0]
220
width = boxes[:, 3] - boxes[:, 1]
221
center_y = boxes[:, 0] + 0.5 * height
222
center_x = boxes[:, 1] + 0.5 * width
223
# Apply deltas
224
center_y += deltas[:, 0] * height
225
center_x += deltas[:, 1] * width
226
height *= tf.exp(deltas[:, 2])
227
width *= tf.exp(deltas[:, 3])
228
# Convert back to y1, x1, y2, x2
229
y1 = center_y - 0.5 * height
230
x1 = center_x - 0.5 * width
231
y2 = y1 + height
232
x2 = x1 + width
233
result = tf.stack([y1, x1, y2, x2], axis=1, name="apply_box_deltas_out")
234
return result
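
# Minimal NumPy sketch of the same refinement math (an illustration, not part
# of the original file): the box center shifts by (dy, dx) fractions of the
# box size, and height/width scale by exp(log(dh)) and exp(log(dw)).
def _example_apply_box_deltas(boxes, deltas):
    """boxes: [N, (y1, x1, y2, x2)], deltas: [N, (dy, dx, log(dh), log(dw))]."""
    h = boxes[:, 2] - boxes[:, 0]
    w = boxes[:, 3] - boxes[:, 1]
    cy = boxes[:, 0] + 0.5 * h + deltas[:, 0] * h
    cx = boxes[:, 1] + 0.5 * w + deltas[:, 1] * w
    h = h * np.exp(deltas[:, 2])
    w = w * np.exp(deltas[:, 3])
    return np.stack([cy - 0.5 * h, cx - 0.5 * w,
                     cy + 0.5 * h, cx + 0.5 * w], axis=1)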
235
236
237
def clip_boxes_graph(boxes, window):
238
"""
239
boxes: [N, (y1, x1, y2, x2)]
240
window: [4] in the form y1, x1, y2, x2
241
"""
242
# Split
243
wy1, wx1, wy2, wx2 = tf.split(window, 4)
244
y1, x1, y2, x2 = tf.split(boxes, 4, axis=1)
245
# Clip
246
y1 = tf.maximum(tf.minimum(y1, wy2), wy1)
247
x1 = tf.maximum(tf.minimum(x1, wx2), wx1)
248
y2 = tf.maximum(tf.minimum(y2, wy2), wy1)
249
x2 = tf.maximum(tf.minimum(x2, wx2), wx1)
250
clipped = tf.concat([y1, x1, y2, x2], axis=1, name="clipped_boxes")
251
clipped.set_shape((clipped.shape[0], 4))
252
return clipped
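
# Quick worked example (illustrative only): clipping the box
# (-0.1, 0.2, 1.3, 0.9) to the window (0, 0, 1, 1) gives (0.0, 0.2, 1.0, 0.9);
# each coordinate is clamped independently to the window's range.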
253
254
255
class ProposalLayer(KE.Layer):
256
"""Receives anchor scores and selects a subset to pass as proposals
257
to the second stage. Filtering is done based on anchor scores and
258
non-max suppression to remove overlaps. It also applies bounding
259
box refinement deltas to anchors.
260
261
Inputs:
262
rpn_probs: [batch, num_anchors, (bg prob, fg prob)]
263
rpn_bbox: [batch, num_anchors, (dy, dx, log(dh), log(dw))]
264
anchors: [batch, num_anchors, (y1, x1, y2, x2)] anchors in normalized coordinates
265
266
Returns:
267
Proposals in normalized coordinates [batch, rois, (y1, x1, y2, x2)]
268
"""
269
270
def __init__(self, proposal_count, nms_threshold, config=None, **kwargs):
271
super(ProposalLayer, self).__init__(**kwargs)
272
self.config = config
273
self.proposal_count = proposal_count
274
self.nms_threshold = nms_threshold
275
276
def call(self, inputs):
277
# Box Scores. Use the foreground class confidence. [batch, num_anchors]
278
scores = inputs[0][:, :, 1]
279
# Box deltas [batch, num_rois, 4]
280
deltas = inputs[1]
281
deltas = deltas * np.reshape(self.config.RPN_BBOX_STD_DEV, [1, 1, 4])
282
# Anchors
283
anchors = inputs[2]
284
285
# Improve performance by trimming to top anchors by score
286
# and doing the rest on the smaller subset.
287
pre_nms_limit = tf.minimum(self.config.PRE_NMS_LIMIT, tf.shape(anchors)[1])
288
ix = tf.nn.top_k(scores, pre_nms_limit, sorted=True,
289
name="top_anchors").indices
290
scores = utils.batch_slice([scores, ix], lambda x, y: tf.gather(x, y),
291
self.config.IMAGES_PER_GPU)
292
deltas = utils.batch_slice([deltas, ix], lambda x, y: tf.gather(x, y),
293
self.config.IMAGES_PER_GPU)
294
pre_nms_anchors = utils.batch_slice([anchors, ix], lambda a, x: tf.gather(a, x),
295
self.config.IMAGES_PER_GPU,
296
names=["pre_nms_anchors"])
297
298
# Apply deltas to anchors to get refined anchors.
299
# [batch, N, (y1, x1, y2, x2)]
300
boxes = utils.batch_slice([pre_nms_anchors, deltas],
301
lambda x, y: apply_box_deltas_graph(x, y),
302
self.config.IMAGES_PER_GPU,
303
names=["refined_anchors"])
304
305
# Clip to image boundaries. Since we're in normalized coordinates,
306
# clip to 0..1 range. [batch, N, (y1, x1, y2, x2)]
307
window = np.array([0, 0, 1, 1], dtype=np.float32)
308
boxes = utils.batch_slice(boxes,
309
lambda x: clip_boxes_graph(x, window),
310
self.config.IMAGES_PER_GPU,
311
names=["refined_anchors_clipped"])
312
313
# Filter out small boxes
314
# According to Xinlei Chen's paper, this reduces detection accuracy
315
# for small objects, so we're skipping it.
316
317
# Non-max suppression
318
def nms(boxes, scores):
319
indices = tf.image.non_max_suppression(
320
boxes, scores, self.proposal_count,
321
self.nms_threshold, name="rpn_non_max_suppression")
322
proposals = tf.gather(boxes, indices)
323
# Pad if needed
324
padding = tf.maximum(self.proposal_count - tf.shape(proposals)[0], 0)
325
proposals = tf.pad(proposals, [(0, padding), (0, 0)])
326
return proposals
327
proposals = utils.batch_slice([boxes, scores], nms,
328
self.config.IMAGES_PER_GPU)
329
return proposals
330
331
def compute_output_shape(self, input_shape):
332
return (None, self.proposal_count, 4)
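
# Illustrative usage sketch (not part of the original file). In the full model
# this layer is created roughly as below, where POST_NMS_ROIS_TRAINING,
# POST_NMS_ROIS_INFERENCE, and RPN_NMS_THRESHOLD are assumed Config attributes:
#
#   proposal_count = config.POST_NMS_ROIS_TRAINING if mode == "training" \
#       else config.POST_NMS_ROIS_INFERENCE
#   rpn_rois = ProposalLayer(
#       proposal_count=proposal_count,
#       nms_threshold=config.RPN_NMS_THRESHOLD,
#       name="ROI",
#       config=config)([rpn_class, rpn_bbox, anchors])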
333
334
335
############################################################
336
# ROIAlign Layer
337
############################################################
338
339
def log2_graph(x):
340
"""Implementation of Log2. TF doesn't have a native implementation."""
341
return tf.log(x) / tf.log(2.0)
342
343
344
class PyramidROIAlign(KE.Layer):
345
"""Implements ROI Pooling on multiple levels of the feature pyramid.
346
347
Params:
348
- pool_shape: [pool_height, pool_width] of the output pooled regions. Usually [7, 7]
349
350
Inputs:
351
- boxes: [batch, num_boxes, (y1, x1, y2, x2)] in normalized
352
coordinates. Possibly padded with zeros if not enough
353
boxes to fill the array.
354
- image_meta: [batch, (meta data)] Image details. See compose_image_meta()
355
- feature_maps: List of feature maps from different levels of the pyramid.
356
Each is [batch, height, width, channels]
357
358
Output:
359
Pooled regions in the shape: [batch, num_boxes, pool_height, pool_width, channels].
360
The width and height are those specified in the pool_shape in the layer
361
constructor.
362
"""
363
364
def __init__(self, pool_shape, **kwargs):
365
super(PyramidROIAlign, self).__init__(**kwargs)
366
self.pool_shape = tuple(pool_shape)
367
368
def call(self, inputs):
369
# Crop boxes [batch, num_boxes, (y1, x1, y2, x2)] in normalized coords
370
boxes = inputs[0]
371
372
# Image meta
373
# Holds details about the image. See compose_image_meta()
374
image_meta = inputs[1]
375
376
# Feature Maps. List of feature maps from different level of the
377
# feature pyramid. Each is [batch, height, width, channels]
378
feature_maps = inputs[2:]
379
380
# Assign each ROI to a level in the pyramid based on the ROI area.
381
y1, x1, y2, x2 = tf.split(boxes, 4, axis=2)
382
h = y2 - y1
383
w = x2 - x1
384
# Use shape of first image. Images in a batch must have the same size.
385
image_shape = parse_image_meta_graph(image_meta)['image_shape'][0]
386
# Equation 1 in the Feature Pyramid Networks paper. Account for
387
# the fact that our coordinates are normalized here.
388
# e.g. a 224x224 ROI (in pixels) maps to P4
389
image_area = tf.cast(image_shape[0] * image_shape[1], tf.float32)
390
roi_level = log2_graph(tf.sqrt(h * w) / (224.0 / tf.sqrt(image_area)))
391
roi_level = tf.minimum(5, tf.maximum(
392
2, 4 + tf.cast(tf.round(roi_level), tf.int32)))
393
roi_level = tf.squeeze(roi_level, 2)
394
395
# Loop through levels and apply ROI pooling to each. P2 to P5.
396
pooled = []
397
box_to_level = []
398
for i, level in enumerate(range(2, 6)):
399
ix = tf.where(tf.equal(roi_level, level))
400
level_boxes = tf.gather_nd(boxes, ix)
401
402
# Box indices for crop_and_resize.
403
box_indices = tf.cast(ix[:, 0], tf.int32)
404
405
# Keep track of which box is mapped to which level
406
box_to_level.append(ix)
407
408
# Stop gradient propagation to ROI proposals
409
level_boxes = tf.stop_gradient(level_boxes)
410
box_indices = tf.stop_gradient(box_indices)
411
412
# Crop and Resize
413
# From Mask R-CNN paper: "We sample four regular locations, so
414
# that we can evaluate either max or average pooling. In fact,
415
# interpolating only a single value at each bin center (without
416
# pooling) is nearly as effective."
417
#
418
# Here we use the simplified approach of a single value per bin,
419
# which is how it's done in tf.crop_and_resize()
420
# Result: [batch * num_boxes, pool_height, pool_width, channels]
421
pooled.append(tf.image.crop_and_resize(
422
feature_maps[i], level_boxes, box_indices, self.pool_shape,
423
method="bilinear"))
424
425
# Pack pooled features into one tensor
426
pooled = tf.concat(pooled, axis=0)
427
428
# Pack box_to_level mapping into one array and add another
429
# column representing the order of pooled boxes
430
box_to_level = tf.concat(box_to_level, axis=0)
431
box_range = tf.expand_dims(tf.range(tf.shape(box_to_level)[0]), 1)
432
box_to_level = tf.concat([tf.cast(box_to_level, tf.int32), box_range],
433
axis=1)
434
435
# Rearrange pooled features to match the order of the original boxes
436
# Sort box_to_level by batch then box index
437
# TF doesn't have a way to sort by two columns, so merge them and sort.
438
sorting_tensor = box_to_level[:, 0] * 100000 + box_to_level[:, 1]
439
ix = tf.nn.top_k(sorting_tensor, k=tf.shape(
440
box_to_level)[0]).indices[::-1]
441
ix = tf.gather(box_to_level[:, 2], ix)
442
pooled = tf.gather(pooled, ix)
443
444
# Re-add the batch dimension
445
shape = tf.concat([tf.shape(boxes)[:2], tf.shape(pooled)[1:]], axis=0)
446
pooled = tf.reshape(pooled, shape)
447
return pooled
448
449
def compute_output_shape(self, input_shape):
450
return input_shape[0][:2] + self.pool_shape + (input_shape[2][-1], )
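
# Worked example of the level assignment above (illustrative, not part of the
# original file). With normalized boxes, sqrt(h * w) * sqrt(image_area) is the
# box size in pixels, so
#     roi_level = 4 + round(log2(box_size_px / 224)), clipped to [2, 5].
# For a 1024x1024 image, a 224x224 box maps to P4, a 112x112 box to P3, and a
# 448x448 box to P5.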
451
452
453
############################################################
454
# Detection Target Layer
455
############################################################
456
457
def overlaps_graph(boxes1, boxes2):
458
"""Computes IoU overlaps between two sets of boxes.
459
boxes1, boxes2: [N, (y1, x1, y2, x2)].
460
"""
461
# 1. Tile boxes2 and repeat boxes1. This allows us to compare
462
# every boxes1 against every boxes2 without loops.
463
# TF doesn't have an equivalent to np.repeat() so simulate it
464
# using tf.tile() and tf.reshape.
465
b1 = tf.reshape(tf.tile(tf.expand_dims(boxes1, 1),
466
[1, 1, tf.shape(boxes2)[0]]), [-1, 4])
467
b2 = tf.tile(boxes2, [tf.shape(boxes1)[0], 1])
468
# 2. Compute intersections
469
b1_y1, b1_x1, b1_y2, b1_x2 = tf.split(b1, 4, axis=1)
470
b2_y1, b2_x1, b2_y2, b2_x2 = tf.split(b2, 4, axis=1)
471
y1 = tf.maximum(b1_y1, b2_y1)
472
x1 = tf.maximum(b1_x1, b2_x1)
473
y2 = tf.minimum(b1_y2, b2_y2)
474
x2 = tf.minimum(b1_x2, b2_x2)
475
intersection = tf.maximum(x2 - x1, 0) * tf.maximum(y2 - y1, 0)
476
# 3. Compute unions
477
b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
478
b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
479
union = b1_area + b2_area - intersection
480
# 4. Compute IoU and reshape to [boxes1, boxes2]
481
iou = intersection / union
482
overlaps = tf.reshape(iou, [tf.shape(boxes1)[0], tf.shape(boxes2)[0]])
483
return overlaps
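
# Minimal NumPy sketch of the same IoU computation (illustrative only), handy
# for checking the TF graph version above on small hand-made inputs.
def _example_overlaps(boxes1, boxes2):
    """boxes1: [N, 4], boxes2: [M, 4] as (y1, x1, y2, x2). Returns [N, M] IoU."""
    y1 = np.maximum(boxes1[:, None, 0], boxes2[None, :, 0])
    x1 = np.maximum(boxes1[:, None, 1], boxes2[None, :, 1])
    y2 = np.minimum(boxes1[:, None, 2], boxes2[None, :, 2])
    x2 = np.minimum(boxes1[:, None, 3], boxes2[None, :, 3])
    intersection = np.maximum(y2 - y1, 0) * np.maximum(x2 - x1, 0)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area1[:, None] + area2[None, :] - intersection
    return intersection / union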
484
485
486
def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config):
487
"""Generates detection targets for one image. Subsamples proposals and
488
generates target class IDs, bounding box deltas, and masks for each.
489
490
Inputs:
491
proposals: [POST_NMS_ROIS_TRAINING, (y1, x1, y2, x2)] in normalized coordinates. Might
492
be zero padded if there are not enough proposals.
493
gt_class_ids: [MAX_GT_INSTANCES] int class IDs
494
gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized coordinates.
495
gt_masks: [height, width, MAX_GT_INSTANCES] of boolean type.
496
497
Returns: Target ROIs and corresponding class IDs, bounding box shifts,
498
and masks.
499
rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized coordinates
500
class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs. Zero padded.
501
deltas: [TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw))]
502
masks: [TRAIN_ROIS_PER_IMAGE, height, width]. Masks cropped to bbox
503
boundaries and resized to neural network output size.
504
505
Note: Returned arrays might be zero padded if not enough target ROIs.
506
"""
507
# Assertions
508
asserts = [
509
tf.Assert(tf.greater(tf.shape(proposals)[0], 0), [proposals],
510
name="roi_assertion"),
511
]
512
with tf.control_dependencies(asserts):
513
proposals = tf.identity(proposals)
514
515
# Remove zero padding
516
proposals, _ = trim_zeros_graph(proposals, name="trim_proposals")
517
gt_boxes, non_zeros = trim_zeros_graph(gt_boxes, name="trim_gt_boxes")
518
gt_class_ids = tf.boolean_mask(gt_class_ids, non_zeros,
519
name="trim_gt_class_ids")
520
gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=2,
521
name="trim_gt_masks")
522
523
# Handle COCO crowds
524
# A crowd box in COCO is a bounding box around several instances. Exclude
525
# them from training. A crowd box is given a negative class ID.
526
crowd_ix = tf.where(gt_class_ids < 0)[:, 0]
527
non_crowd_ix = tf.where(gt_class_ids > 0)[:, 0]
528
crowd_boxes = tf.gather(gt_boxes, crowd_ix)
529
gt_class_ids = tf.gather(gt_class_ids, non_crowd_ix)
530
gt_boxes = tf.gather(gt_boxes, non_crowd_ix)
531
gt_masks = tf.gather(gt_masks, non_crowd_ix, axis=2)
532
533
# Compute overlaps matrix [proposals, gt_boxes]
534
overlaps = overlaps_graph(proposals, gt_boxes)
535
536
# Compute overlaps with crowd boxes [proposals, crowd_boxes]
537
crowd_overlaps = overlaps_graph(proposals, crowd_boxes)
538
crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1)
539
no_crowd_bool = (crowd_iou_max < 0.001)
540
541
# Determine positive and negative ROIs
542
roi_iou_max = tf.reduce_max(overlaps, axis=1)
543
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box
544
positive_roi_bool = (roi_iou_max >= 0.5)
545
positive_indices = tf.where(positive_roi_bool)[:, 0]
546
# 2. Negative ROIs are those with < 0.5 with every GT box. Skip crowds.
547
negative_indices = tf.where(tf.logical_and(roi_iou_max < 0.5, no_crowd_bool))[:, 0]
548
549
# Subsample ROIs. Aim for 33% positive
550
# Positive ROIs
551
positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
552
config.ROI_POSITIVE_RATIO)
553
positive_indices = tf.random_shuffle(positive_indices)[:positive_count]
554
positive_count = tf.shape(positive_indices)[0]
555
# Negative ROIs. Add enough to maintain positive:negative ratio.
556
r = 1.0 / config.ROI_POSITIVE_RATIO
557
negative_count = tf.cast(r * tf.cast(positive_count, tf.float32), tf.int32) - positive_count
558
negative_indices = tf.random_shuffle(negative_indices)[:negative_count]
559
# Gather selected ROIs
560
positive_rois = tf.gather(proposals, positive_indices)
561
negative_rois = tf.gather(proposals, negative_indices)
562
563
# Assign positive ROIs to GT boxes.
564
positive_overlaps = tf.gather(overlaps, positive_indices)
565
roi_gt_box_assignment = tf.cond(
566
tf.greater(tf.shape(positive_overlaps)[1], 0),
567
true_fn = lambda: tf.argmax(positive_overlaps, axis=1),
568
false_fn = lambda: tf.cast(tf.constant([]),tf.int64)
569
)
570
roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)
571
roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)
572
573
# Compute bbox refinement for positive ROIs
574
deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes)
575
deltas /= config.BBOX_STD_DEV
576
577
# Assign positive ROIs to GT masks
578
# Permute masks to [N, height, width, 1]
579
transposed_masks = tf.expand_dims(tf.transpose(gt_masks, [2, 0, 1]), -1)
580
# Pick the right mask for each ROI
581
roi_masks = tf.gather(transposed_masks, roi_gt_box_assignment)
582
583
# Compute mask targets
584
boxes = positive_rois
585
if config.USE_MINI_MASK:
586
# Transform ROI coordinates from normalized image space
587
# to normalized mini-mask space.
588
y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1)
589
gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1)
590
gt_h = gt_y2 - gt_y1
591
gt_w = gt_x2 - gt_x1
592
y1 = (y1 - gt_y1) / gt_h
593
x1 = (x1 - gt_x1) / gt_w
594
y2 = (y2 - gt_y1) / gt_h
595
x2 = (x2 - gt_x1) / gt_w
596
boxes = tf.concat([y1, x1, y2, x2], 1)
597
box_ids = tf.range(0, tf.shape(roi_masks)[0])
598
masks = tf.image.crop_and_resize(tf.cast(roi_masks, tf.float32), boxes,
599
box_ids,
600
config.MASK_SHAPE)
601
# Remove the extra dimension from masks.
602
masks = tf.squeeze(masks, axis=3)
603
604
# Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
605
# binary cross entropy loss.
606
masks = tf.round(masks)
607
608
# Append negative ROIs and pad bbox deltas and masks that
609
# are not used for negative ROIs with zeros.
610
rois = tf.concat([positive_rois, negative_rois], axis=0)
611
N = tf.shape(negative_rois)[0]
612
P = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
613
rois = tf.pad(rois, [(0, P), (0, 0)])
614
roi_gt_boxes = tf.pad(roi_gt_boxes, [(0, N + P), (0, 0)])
615
roi_gt_class_ids = tf.pad(roi_gt_class_ids, [(0, N + P)])
616
deltas = tf.pad(deltas, [(0, N + P), (0, 0)])
617
masks = tf.pad(masks, [[0, N + P], (0, 0), (0, 0)])
618
619
return rois, roi_gt_class_ids, deltas, masks
620
621
622
class DetectionTargetLayer(KE.Layer):
623
"""Subsamples proposals and generates target box refinement, class_ids,
624
and masks for each.
625
626
Inputs:
627
proposals: [batch, N, (y1, x1, y2, x2)] in normalized coordinates. Might
628
be zero padded if there are not enough proposals.
629
gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs.
630
gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized
631
coordinates.
632
gt_masks: [batch, height, width, MAX_GT_INSTANCES] of boolean type
633
634
Returns: Target ROIs and corresponding class IDs, bounding box shifts,
635
and masks.
636
rois: [batch, TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized
637
coordinates
638
target_class_ids: [batch, TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
639
target_deltas: [batch, TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw))]
640
target_mask: [batch, TRAIN_ROIS_PER_IMAGE, height, width]
641
Masks cropped to bbox boundaries and resized to neural
642
network output size.
643
644
Note: Returned arrays might be zero padded if not enough target ROIs.
645
"""
646
647
def __init__(self, config, **kwargs):
648
super(DetectionTargetLayer, self).__init__(**kwargs)
649
self.config = config
650
651
def call(self, inputs):
652
proposals = inputs[0]
653
gt_class_ids = inputs[1]
654
gt_boxes = inputs[2]
655
gt_masks = inputs[3]
656
657
# Slice the batch and run a graph for each slice
658
# TODO: Rename target_bbox to target_deltas for clarity
659
names = ["rois", "target_class_ids", "target_bbox", "target_mask"]
660
outputs = utils.batch_slice(
661
[proposals, gt_class_ids, gt_boxes, gt_masks],
662
lambda w, x, y, z: detection_targets_graph(
663
w, x, y, z, self.config),
664
self.config.IMAGES_PER_GPU, names=names)
665
return outputs
666
667
def compute_output_shape(self, input_shape):
668
return [
669
(None, self.config.TRAIN_ROIS_PER_IMAGE, 4), # rois
670
(None, self.config.TRAIN_ROIS_PER_IMAGE), # class_ids
671
(None, self.config.TRAIN_ROIS_PER_IMAGE, 4), # deltas
672
(None, self.config.TRAIN_ROIS_PER_IMAGE, self.config.MASK_SHAPE[0],
673
self.config.MASK_SHAPE[1]) # masks
674
]
675
676
def compute_mask(self, inputs, mask=None):
677
return [None, None, None, None]
678
679
680
############################################################
681
# Detection Layer
682
############################################################
683
684
def refine_detections_graph(rois, probs, deltas, window, config):
685
"""Refine classified proposals and filter overlaps and return final
686
detections.
687
688
Inputs:
689
rois: [N, (y1, x1, y2, x2)] in normalized coordinates
690
probs: [N, num_classes]. Class probabilities.
691
deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
692
bounding box deltas.
693
window: (y1, x1, y2, x2) in normalized coordinates. The part of the image
694
that contains the actual image, excluding the padding.
695
696
Returns detections shaped: [num_detections, (y1, x1, y2, x2, class_id, score)] where
697
coordinates are normalized.
698
"""
699
# Class IDs per ROI
700
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
701
# Class probability of the top class of each ROI
702
indices = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
703
class_scores = tf.gather_nd(probs, indices)
704
# Class-specific bounding box deltas
705
deltas_specific = tf.gather_nd(deltas, indices)
706
# Apply bounding box deltas
707
# Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
708
refined_rois = apply_box_deltas_graph(
709
rois, deltas_specific * config.BBOX_STD_DEV)
710
# Clip boxes to image window
711
refined_rois = clip_boxes_graph(refined_rois, window)
712
713
# TODO: Filter out boxes with zero area
714
715
# Filter out background boxes
716
keep = tf.where(class_ids > 0)[:, 0]
717
# Filter out low confidence boxes
718
if config.DETECTION_MIN_CONFIDENCE:
719
conf_keep = tf.where(class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
720
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
721
tf.expand_dims(conf_keep, 0))
722
keep = tf.sparse_tensor_to_dense(keep)[0]
723
724
# Apply per-class NMS
725
# 1. Prepare variables
726
pre_nms_class_ids = tf.gather(class_ids, keep)
727
pre_nms_scores = tf.gather(class_scores, keep)
728
pre_nms_rois = tf.gather(refined_rois, keep)
729
unique_pre_nms_class_ids = tf.unique(pre_nms_class_ids)[0]
730
731
def nms_keep_map(class_id):
732
"""Apply Non-Maximum Suppression on ROIs of the given class."""
733
# Indices of ROIs of the given class
734
ixs = tf.where(tf.equal(pre_nms_class_ids, class_id))[:, 0]
735
# Apply NMS
736
class_keep = tf.image.non_max_suppression(
737
tf.gather(pre_nms_rois, ixs),
738
tf.gather(pre_nms_scores, ixs),
739
max_output_size=config.DETECTION_MAX_INSTANCES,
740
iou_threshold=config.DETECTION_NMS_THRESHOLD)
741
# Map indices
742
class_keep = tf.gather(keep, tf.gather(ixs, class_keep))
743
# Pad with -1 so returned tensors have the same shape
744
gap = config.DETECTION_MAX_INSTANCES - tf.shape(class_keep)[0]
745
class_keep = tf.pad(class_keep, [(0, gap)],
746
mode='CONSTANT', constant_values=-1)
747
# Set shape so map_fn() can infer result shape
748
class_keep.set_shape([config.DETECTION_MAX_INSTANCES])
749
return class_keep
750
751
# 2. Map over class IDs
752
nms_keep = tf.map_fn(nms_keep_map, unique_pre_nms_class_ids,
753
dtype=tf.int64)
754
# 3. Merge results into one list, and remove -1 padding
755
nms_keep = tf.reshape(nms_keep, [-1])
756
nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
757
# 4. Compute intersection between keep and nms_keep
758
keep = tf.sets.set_intersection(tf.expand_dims(keep, 0),
759
tf.expand_dims(nms_keep, 0))
760
keep = tf.sparse_tensor_to_dense(keep)[0]
761
# Keep top detections
762
roi_count = config.DETECTION_MAX_INSTANCES
763
class_scores_keep = tf.gather(class_scores, keep)
764
num_keep = tf.minimum(tf.shape(class_scores_keep)[0], roi_count)
765
top_ids = tf.nn.top_k(class_scores_keep, k=num_keep, sorted=True)[1]
766
keep = tf.gather(keep, top_ids)
767
768
# Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
769
# Coordinates are normalized.
770
detections = tf.concat([
771
tf.gather(refined_rois, keep),
772
tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
773
tf.gather(class_scores, keep)[..., tf.newaxis]
774
], axis=1)
775
776
# Pad with zeros if detections < DETECTION_MAX_INSTANCES
777
gap = config.DETECTION_MAX_INSTANCES - tf.shape(detections)[0]
778
detections = tf.pad(detections, [(0, gap), (0, 0)], "CONSTANT")
779
return detections
780
781
782
class DetectionLayer(KE.Layer):
783
"""Takes classified proposal boxes and their bounding box deltas and
784
returns the final detection boxes.
785
786
Returns:
787
[batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] where
788
coordinates are normalized.
789
"""
790
791
def __init__(self, config=None, **kwargs):
792
super(DetectionLayer, self).__init__(**kwargs)
793
self.config = config
794
795
def call(self, inputs):
796
rois = inputs[0]
797
mrcnn_class = inputs[1]
798
mrcnn_bbox = inputs[2]
799
image_meta = inputs[3]
800
801
# Get windows of images in normalized coordinates. Windows are the area
802
# in the image that excludes the padding.
803
# Use the shape of the first image in the batch to normalize the window
804
# because we know that all images get resized to the same size.
805
m = parse_image_meta_graph(image_meta)
806
image_shape = m['image_shape'][0]
807
window = norm_boxes_graph(m['window'], image_shape[:2])
808
809
# Run detection refinement graph on each item in the batch
810
detections_batch = utils.batch_slice(
811
[rois, mrcnn_class, mrcnn_bbox, window],
812
lambda x, y, w, z: refine_detections_graph(x, y, w, z, self.config),
813
self.config.IMAGES_PER_GPU)
814
815
# Reshape output
816
# [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] in
817
# normalized coordinates
818
return tf.reshape(
819
detections_batch,
820
[self.config.BATCH_SIZE, self.config.DETECTION_MAX_INSTANCES, 6])
821
822
def compute_output_shape(self, input_shape):
823
return (None, self.config.DETECTION_MAX_INSTANCES, 6)
824
825
826
############################################################
827
# Region Proposal Network (RPN)
828
############################################################
829
830
def rpn_graph(feature_map, anchors_per_location, anchor_stride):
831
"""Builds the computation graph of Region Proposal Network.
832
833
feature_map: backbone features [batch, height, width, depth]
834
anchors_per_location: number of anchors per pixel in the feature map
835
anchor_stride: Controls the density of anchors. Typically 1 (anchors for
836
every pixel in the feature map), or 2 (every other pixel).
837
838
Returns:
839
rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor classifier logits (before softmax)
840
rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier probabilities.
841
rpn_bbox: [batch, H * W * anchors_per_location, (dy, dx, log(dh), log(dw))] Deltas to be
842
applied to anchors.
843
"""
844
# TODO: check if stride of 2 causes alignment issues if the feature map
845
# is not even.
846
# Shared convolutional base of the RPN
847
shared = KL.Conv2D(512, (3, 3), padding='same', activation='relu',
848
strides=anchor_stride,
849
name='rpn_conv_shared')(feature_map)
850
851
# Anchor Score. [batch, height, width, anchors per location * 2].
852
x = KL.Conv2D(2 * anchors_per_location, (1, 1), padding='valid',
853
activation='linear', name='rpn_class_raw')(shared)
854
855
# Reshape to [batch, anchors, 2]
856
rpn_class_logits = KL.Lambda(
857
lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 2]))(x)
858
859
# Softmax on last dimension of BG/FG.
860
rpn_probs = KL.Activation(
861
"softmax", name="rpn_class_xxx")(rpn_class_logits)
862
863
# Bounding box refinement. [batch, H, W, anchors per location * depth]
864
# where depth is [x, y, log(w), log(h)]
865
x = KL.Conv2D(anchors_per_location * 4, (1, 1), padding="valid",
866
activation='linear', name='rpn_bbox_pred')(shared)
867
868
# Reshape to [batch, anchors, 4]
869
rpn_bbox = KL.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 4]))(x)
870
871
return [rpn_class_logits, rpn_probs, rpn_bbox]
872
873
874
def build_rpn_model(anchor_stride, anchors_per_location, depth):
875
"""Builds a Keras model of the Region Proposal Network.
876
It wraps the RPN graph so it can be used multiple times with shared
877
weights.
878
879
anchors_per_location: number of anchors per pixel in the feature map
880
anchor_stride: Controls the density of anchors. Typically 1 (anchors for
881
every pixel in the feature map), or 2 (every other pixel).
882
depth: Depth of the backbone feature map.
883
884
Returns a Keras Model object. The model outputs, when called, are:
885
rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor classifier logits (before softmax)
886
rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier probabilities.
887
rpn_bbox: [batch, H * W * anchors_per_location, (dy, dx, log(dh), log(dw))] Deltas to be
888
applied to anchors.
889
"""
890
input_feature_map = KL.Input(shape=[None, None, depth],
891
name="input_rpn_feature_map")
892
outputs = rpn_graph(input_feature_map, anchors_per_location, anchor_stride)
893
return KM.Model([input_feature_map], outputs, name="rpn_model")
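
# Illustrative usage sketch (not part of the original file): one shared RPN
# applied to an FPN feature map. The values below (anchor stride 1, 3 anchors
# per location, 256-channel features) are typical defaults assumed here, not
# read from any particular Config.
def _example_rpn():
    rpn = build_rpn_model(anchor_stride=1, anchors_per_location=3, depth=256)
    feature_map = KL.Input(shape=[None, None, 256], name="example_feature_map")
    rpn_class_logits, rpn_probs, rpn_bbox = rpn([feature_map])
    return KM.Model([feature_map], [rpn_class_logits, rpn_probs, rpn_bbox],
                    name="example_rpn")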
894
895
896
############################################################
897
# Feature Pyramid Network Heads
898
############################################################
899
900
def fpn_classifier_graph(rois, feature_maps, image_meta,
901
pool_size, num_classes, train_bn=True,
902
fc_layers_size=1024):
903
"""Builds the computation graph of the feature pyramid network classifier
904
and regressor heads.
905
906
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
907
coordinates.
908
feature_maps: List of feature maps from different layers of the pyramid,
909
[P2, P3, P4, P5]. Each has a different resolution.
910
image_meta: [batch, (meta data)] Image details. See compose_image_meta()
911
pool_size: The width of the square feature map generated from ROI Pooling.
912
num_classes: number of classes, which determines the depth of the results
913
train_bn: Boolean. Train or freeze Batch Norm layers
914
fc_layers_size: Size of the 2 FC layers
915
916
Returns:
917
logits: [batch, num_rois, NUM_CLASSES] classifier logits (before softmax)
918
probs: [batch, num_rois, NUM_CLASSES] classifier probabilities
919
bbox_deltas: [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))] Deltas to apply to
920
proposal boxes
921
"""
922
# ROI Pooling
923
# Shape: [batch, num_rois, POOL_SIZE, POOL_SIZE, channels]
924
x = PyramidROIAlign([pool_size, pool_size],
925
name="roi_align_classifier")([rois, image_meta] + feature_maps)
926
# Two 1024 FC layers (implemented with Conv2D for consistency)
927
x = KL.TimeDistributed(KL.Conv2D(fc_layers_size, (pool_size, pool_size), padding="valid"),
928
name="mrcnn_class_conv1")(x)
929
x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn1')(x, training=train_bn)
930
x = KL.Activation('relu')(x)
931
x = KL.TimeDistributed(KL.Conv2D(fc_layers_size, (1, 1)),
932
name="mrcnn_class_conv2")(x)
933
x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn2')(x, training=train_bn)
934
x = KL.Activation('relu')(x)
935
936
shared = KL.Lambda(lambda x: K.squeeze(K.squeeze(x, 3), 2),
937
name="pool_squeeze")(x)
938
939
# Classifier head
940
mrcnn_class_logits = KL.TimeDistributed(KL.Dense(num_classes),
941
name='mrcnn_class_logits')(shared)
942
mrcnn_probs = KL.TimeDistributed(KL.Activation("softmax"),
943
name="mrcnn_class")(mrcnn_class_logits)
944
945
# BBox head
946
# [batch, num_rois, NUM_CLASSES * (dy, dx, log(dh), log(dw))]
947
x = KL.TimeDistributed(KL.Dense(num_classes * 4, activation='linear'),
948
name='mrcnn_bbox_fc')(shared)
949
# Reshape to [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))]
950
s = K.int_shape(x)
951
mrcnn_bbox = KL.Reshape((s[1], num_classes, 4), name="mrcnn_bbox")(x)
952
953
return mrcnn_class_logits, mrcnn_probs, mrcnn_bbox
954
955
956
def build_fpn_mask_graph(rois, feature_maps, image_meta,
957
pool_size, num_classes, train_bn=True):
958
"""Builds the computation graph of the mask head of Feature Pyramid Network.
959
960
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
961
coordinates.
962
feature_maps: List of feature maps from different layers of the pyramid,
963
[P2, P3, P4, P5]. Each has a different resolution.
964
image_meta: [batch, (meta data)] Image details. See compose_image_meta()
965
pool_size: The width of the square feature map generated from ROI Pooling.
966
num_classes: number of classes, which determines the depth of the results
967
train_bn: Boolean. Train or freeze Batch Norm layers
968
969
Returns: Masks [batch, num_rois, MASK_POOL_SIZE, MASK_POOL_SIZE, NUM_CLASSES]
970
"""
971
# ROI Pooling
972
# Shape: [batch, num_rois, MASK_POOL_SIZE, MASK_POOL_SIZE, channels]
973
x = PyramidROIAlign([pool_size, pool_size],
974
name="roi_align_mask")([rois, image_meta] + feature_maps)
975
976
# Conv layers
977
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
978
name="mrcnn_mask_conv1")(x)
979
x = KL.TimeDistributed(BatchNorm(),
980
name='mrcnn_mask_bn1')(x, training=train_bn)
981
x = KL.Activation('relu')(x)
982
983
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
984
name="mrcnn_mask_conv2")(x)
985
x = KL.TimeDistributed(BatchNorm(),
986
name='mrcnn_mask_bn2')(x, training=train_bn)
987
x = KL.Activation('relu')(x)
988
989
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
990
name="mrcnn_mask_conv3")(x)
991
x = KL.TimeDistributed(BatchNorm(),
992
name='mrcnn_mask_bn3')(x, training=train_bn)
993
x = KL.Activation('relu')(x)
994
995
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
996
name="mrcnn_mask_conv4")(x)
997
x = KL.TimeDistributed(BatchNorm(),
998
name='mrcnn_mask_bn4')(x, training=train_bn)
999
x = KL.Activation('relu')(x)
1000
1001
x = KL.TimeDistributed(KL.Conv2DTranspose(256, (2, 2), strides=2, activation="relu"),
1002
name="mrcnn_mask_deconv")(x)
1003
x = KL.TimeDistributed(KL.Conv2D(num_classes, (1, 1), strides=1, activation="sigmoid"),
1004
name="mrcnn_mask")(x)
1005
return x
1006
1007
1008
############################################################
1009
# Loss Functions
1010
############################################################
1011
1012
def smooth_l1_loss(y_true, y_pred):
1013
"""Implements Smooth-L1 loss.
1014
y_true and y_pred are typically: [N, 4], but could be any shape.
1015
"""
1016
diff = K.abs(y_true - y_pred)
1017
less_than_one = K.cast(K.less(diff, 1.0), "float32")
1018
loss = (less_than_one * 0.5 * diff**2) + (1 - less_than_one) * (diff - 0.5)
1019
return loss
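
# Worked example (illustrative only): with diff = |y_true - y_pred|,
#     loss = 0.5 * diff**2   if diff < 1
#          = diff - 0.5      otherwise
# so diff = 0.5 gives 0.125 and diff = 2.0 gives 1.5. The two branches meet at
# diff = 1 with value 0.5, which is what keeps the loss smooth there.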
1020
1021
1022
def rpn_class_loss_graph(rpn_match, rpn_class_logits):
1023
"""RPN anchor classifier loss.
1024
1025
rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
1026
-1=negative, 0=neutral anchor.
1027
rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for BG/FG.
1028
"""
1029
# Squeeze last dim to simplify
1030
rpn_match = tf.squeeze(rpn_match, -1)
1031
# Get anchor classes. Convert the -1/+1 match to 0/1 values.
1032
anchor_class = K.cast(K.equal(rpn_match, 1), tf.int32)
1033
# Positive and Negative anchors contribute to the loss,
1034
# but neutral anchors (match value = 0) don't.
1035
indices = tf.where(K.not_equal(rpn_match, 0))
1036
# Pick rows that contribute to the loss and filter out the rest.
1037
rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
1038
anchor_class = tf.gather_nd(anchor_class, indices)
1039
# Cross entropy loss
1040
loss = K.sparse_categorical_crossentropy(target=anchor_class,
1041
output=rpn_class_logits,
1042
from_logits=True)
1043
loss = K.switch(tf.size(loss) > 0, K.mean(loss), tf.constant(0.0))
1044
return loss
1045
1046
1047
def rpn_bbox_loss_graph(config, target_bbox, rpn_match, rpn_bbox):
1048
"""Return the RPN bounding box loss graph.
1049
1050
config: the model config object.
1051
target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))].
1052
Uses 0 padding to fill in unused bbox deltas.
1053
rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
1054
-1=negative, 0=neutral anchor.
1055
rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]
1056
"""
1057
# Positive anchors contribute to the loss, but negative and
1058
# neutral anchors (match value of 0 or -1) don't.
1059
rpn_match = K.squeeze(rpn_match, -1)
1060
indices = tf.where(K.equal(rpn_match, 1))
1061
1062
# Pick bbox deltas that contribute to the loss
1063
rpn_bbox = tf.gather_nd(rpn_bbox, indices)
1064
1065
# Trim target bounding box deltas to the same length as rpn_bbox.
1066
batch_counts = K.sum(K.cast(K.equal(rpn_match, 1), tf.int32), axis=1)
1067
target_bbox = batch_pack_graph(target_bbox, batch_counts,
1068
config.IMAGES_PER_GPU)
1069
1070
loss = smooth_l1_loss(target_bbox, rpn_bbox)
1071
1072
loss = K.switch(tf.size(loss) > 0, K.mean(loss), tf.constant(0.0))
1073
return loss
1074
1075
1076
def mrcnn_class_loss_graph(target_class_ids, pred_class_logits,
1077
active_class_ids):
1078
"""Loss for the classifier head of Mask RCNN.
1079
1080
target_class_ids: [batch, num_rois]. Integer class IDs. Uses zero
1081
padding to fill in the array.
1082
pred_class_logits: [batch, num_rois, num_classes]
1083
active_class_ids: [batch, num_classes]. Has a value of 1 for
1084
classes that are in the dataset of the image, and 0
1085
for classes that are not in the dataset.
1086
"""
1087
# During model building, Keras calls this function with
1088
# target_class_ids of type float32. Unclear why. Cast it
1089
# to int to get around it.
1090
target_class_ids = tf.cast(target_class_ids, 'int64')
1091
1092
# Find predictions of classes that are not in the dataset.
1093
pred_class_ids = tf.argmax(pred_class_logits, axis=2)
1094
# TODO: Update this line to work with batch > 1. Right now it assumes all
1095
# images in a batch have the same active_class_ids
1096
pred_active = tf.gather(active_class_ids[0], pred_class_ids)
1097
1098
# Loss
1099
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
1100
labels=target_class_ids, logits=pred_class_logits)
1101
1102
# Erase losses of predictions of classes that are not in the active
1103
# classes of the image.
1104
loss = loss * pred_active
1105
1106
# Compute loss mean. Use only predictions that contribute
1107
# to the loss to get a correct mean.
1108
loss = tf.reduce_sum(loss) / tf.reduce_sum(pred_active)
1109
return loss
1110
1111
1112
def mrcnn_bbox_loss_graph(target_bbox, target_class_ids, pred_bbox):
1113
"""Loss for Mask R-CNN bounding box refinement.
1114
1115
target_bbox: [batch, num_rois, (dy, dx, log(dh), log(dw))]
1116
target_class_ids: [batch, num_rois]. Integer class IDs.
1117
pred_bbox: [batch, num_rois, num_classes, (dy, dx, log(dh), log(dw))]
1118
"""
1119
# Reshape to merge batch and roi dimensions for simplicity.
1120
target_class_ids = K.reshape(target_class_ids, (-1,))
1121
target_bbox = K.reshape(target_bbox, (-1, 4))
1122
pred_bbox = K.reshape(pred_bbox, (-1, K.int_shape(pred_bbox)[2], 4))
1123
1124
# Only positive ROIs contribute to the loss. And only
1125
# the right class_id of each ROI. Get their indices.
1126
positive_roi_ix = tf.where(target_class_ids > 0)[:, 0]
1127
positive_roi_class_ids = tf.cast(
1128
tf.gather(target_class_ids, positive_roi_ix), tf.int64)
1129
indices = tf.stack([positive_roi_ix, positive_roi_class_ids], axis=1)
1130
1131
# Gather the deltas (predicted and true) that contribute to loss
1132
target_bbox = tf.gather(target_bbox, positive_roi_ix)
1133
pred_bbox = tf.gather_nd(pred_bbox, indices)
1134
1135
# Smooth-L1 Loss
1136
loss = K.switch(tf.size(target_bbox) > 0,
1137
smooth_l1_loss(y_true=target_bbox, y_pred=pred_bbox),
1138
tf.constant(0.0))
1139
loss = K.mean(loss)
1140
return loss
1141
1142
1143
def mrcnn_mask_loss_graph(target_masks, target_class_ids, pred_masks):
1144
"""Mask binary cross-entropy loss for the masks head.
1145
1146
target_masks: [batch, num_rois, height, width].
1147
A float32 tensor of values 0 or 1. Uses zero padding to fill array.
1148
target_class_ids: [batch, num_rois]. Integer class IDs. Zero padded.
1149
pred_masks: [batch, proposals, height, width, num_classes] float32 tensor
1150
with values from 0 to 1.
1151
"""
1152
# Reshape for simplicity. Merge first two dimensions into one.
1153
target_class_ids = K.reshape(target_class_ids, (-1,))
1154
mask_shape = tf.shape(target_masks)
1155
target_masks = K.reshape(target_masks, (-1, mask_shape[2], mask_shape[3]))
1156
pred_shape = tf.shape(pred_masks)
1157
pred_masks = K.reshape(pred_masks,
1158
(-1, pred_shape[2], pred_shape[3], pred_shape[4]))
1159
# Permute predicted masks to [N, num_classes, height, width]
1160
pred_masks = tf.transpose(pred_masks, [0, 3, 1, 2])
1161
1162
# Only positive ROIs contribute to the loss. And only
1163
# the class specific mask of each ROI.
1164
positive_ix = tf.where(target_class_ids > 0)[:, 0]
1165
positive_class_ids = tf.cast(
1166
tf.gather(target_class_ids, positive_ix), tf.int64)
1167
indices = tf.stack([positive_ix, positive_class_ids], axis=1)
1168
1169
# Gather the masks (predicted and true) that contribute to loss
1170
y_true = tf.gather(target_masks, positive_ix)
1171
y_pred = tf.gather_nd(pred_masks, indices)
1172
1173
# Compute binary cross entropy. If no positive ROIs, then return 0.
1174
# shape: [batch, roi, num_classes]
1175
loss = K.switch(tf.size(y_true) > 0,
1176
K.binary_crossentropy(target=y_true, output=y_pred),
1177
tf.constant(0.0))
1178
loss = K.mean(loss)
1179
return loss
1180
1181
1182
############################################################
1183
# Data Generator
1184
############################################################
1185
1186
def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
1187
use_mini_mask=False):
1188
"""Load and return ground truth data for an image (image, mask, bounding boxes).
1189
1190
augment: (deprecated. Use augmentation instead). If true, apply random
1191
image augmentation. Currently, only horizontal flipping is offered.
1192
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
1193
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
1194
right/left 50% of the time.
1195
use_mini_mask: If False, returns full-size masks that are the same height
1196
and width as the original image. These can be big, for example
1197
1024x1024x100 (for 100 instances). Mini masks are smaller, typically,
1198
224x224 and are generated by extracting the bounding box of the
1199
object and resizing it to MINI_MASK_SHAPE.
1200
1201
Returns:
1202
image: [height, width, 3]
1203
image_meta: Image meta data, including the original shape of the image before resizing and cropping. See compose_image_meta()
1204
class_ids: [instance_count] Integer class IDs
1205
bbox: [instance_count, (y1, x1, y2, x2)]
1206
mask: [height, width, instance_count]. The height and width are those
1207
of the image unless use_mini_mask is True, in which case they are
1208
defined in MINI_MASK_SHAPE.
1209
"""
1210
# Load image and mask
1211
image = dataset.load_image(image_id)
1212
mask, class_ids = dataset.load_mask(image_id)
1213
original_shape = image.shape
1214
image, window, scale, padding, crop = utils.resize_image(
1215
image,
1216
min_dim=config.IMAGE_MIN_DIM,
1217
min_scale=config.IMAGE_MIN_SCALE,
1218
max_dim=config.IMAGE_MAX_DIM,
1219
mode=config.IMAGE_RESIZE_MODE)
1220
mask = utils.resize_mask(mask, scale, padding, crop)
1221
1222
# Random horizontal flips.
1223
# TODO: will be removed in a future update in favor of augmentation
1224
if augment:
1225
logging.warning("'augment' is deprecated. Use 'augmentation' instead.")
1226
if random.randint(0, 1):
1227
image = np.fliplr(image)
1228
mask = np.fliplr(mask)
1229
1230
# Augmentation
1231
# This requires the imgaug lib (https://github.com/aleju/imgaug)
1232
if augmentation:
1233
import imgaug
1234
1235
# Augmenters that are safe to apply to masks
1236
# Some, such as Affine, have settings that make them unsafe, so always
1237
# test your augmentation on masks
1238
MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
1239
"Fliplr", "Flipud", "CropAndPad",
1240
"Affine", "PiecewiseAffine"]
1241
1242
def hook(images, augmenter, parents, default):
1243
"""Determines which augmenters to apply to masks."""
1244
return augmenter.__class__.__name__ in MASK_AUGMENTERS
1245
1246
# Store shapes before augmentation to compare
1247
image_shape = image.shape
1248
mask_shape = mask.shape
1249
# Make augmenters deterministic to apply similarly to images and masks
1250
det = augmentation.to_deterministic()
1251
image = det.augment_image(image)
1252
# Change mask to np.uint8 because imgaug doesn't support np.bool
1253
mask = det.augment_image(mask.astype(np.uint8),
1254
hooks=imgaug.HooksImages(activator=hook))
1255
# Verify that shapes didn't change
1256
assert image.shape == image_shape, "Augmentation shouldn't change image size"
1257
assert mask.shape == mask_shape, "Augmentation shouldn't change mask size"
1258
# Change mask back to bool
1259
mask = mask.astype(np.bool)
1260
1261
# Note that some boxes might be all zeros if the corresponding mask got cropped out.
1262
# so we filter them out here.
1263
_idx = np.sum(mask, axis=(0, 1)) > 0
1264
mask = mask[:, :, _idx]
1265
class_ids = class_ids[_idx]
1266
# Bounding boxes. Note that some boxes might be all zeros
1267
# if the corresponding mask got cropped out.
1268
# bbox: [num_instances, (y1, x1, y2, x2)]
1269
bbox = utils.extract_bboxes(mask)
1270
1271
# Active classes
1272
# Different datasets have different classes, so track the
1273
# classes supported in the dataset of this image.
1274
active_class_ids = np.zeros([dataset.num_classes], dtype=np.int32)
1275
source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
1276
active_class_ids[source_class_ids] = 1
1277
1278
# Resize masks to smaller size to reduce memory usage
1279
if use_mini_mask:
1280
mask = utils.minimize_mask(bbox, mask, config.MINI_MASK_SHAPE)
1281
1282
# Image meta data
1283
image_meta = compose_image_meta(image_id, original_shape, image.shape,
1284
window, scale, active_class_ids)
1285
1286
return image, image_meta, class_ids, bbox, mask
1287
1288
1289
def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
1290
"""Generate targets for training Stage 2 classifier and mask heads.
1291
This is not used in normal training. It's useful for debugging or to train
1292
the Mask RCNN heads without using the RPN head.
1293
1294
Inputs:
1295
rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
1296
gt_class_ids: [instance count] Integer class IDs
1297
gt_boxes: [instance count, (y1, x1, y2, x2)]
1298
gt_masks: [height, width, instance count] Ground truth masks. Can be full
1299
size or mini-masks.
1300
1301
Returns:
1302
rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)]
1303
class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
1304
bboxes: [TRAIN_ROIS_PER_IMAGE, NUM_CLASSES, (y, x, log(h), log(w))]. Class-specific
1305
bbox refinements.
1306
masks: [TRAIN_ROIS_PER_IMAGE, height, width, NUM_CLASSES]. Class-specific masks cropped
1307
to bbox boundaries and resized to neural network output size.
1308
"""
1309
assert rpn_rois.shape[0] > 0
1310
assert gt_class_ids.dtype == np.int32, "Expected int but got {}".format(
1311
gt_class_ids.dtype)
1312
assert gt_boxes.dtype == np.int32, "Expected int but got {}".format(
1313
gt_boxes.dtype)
1314
assert gt_masks.dtype == np.bool_, "Expected bool but got {}".format(
1315
gt_masks.dtype)
1316
1317
# It's common to add GT Boxes to ROIs but we don't do that here because
1318
# according to XinLei Chen's paper, it doesn't help.
1319
1320
# Trim empty padding in gt_boxes and gt_masks parts
1321
instance_ids = np.where(gt_class_ids > 0)[0]
1322
assert instance_ids.shape[0] > 0, "Image must contain instances."
1323
gt_class_ids = gt_class_ids[instance_ids]
1324
gt_boxes = gt_boxes[instance_ids]
1325
gt_masks = gt_masks[:, :, instance_ids]
1326
1327
# Compute areas of ROIs and ground truth boxes.
1328
rpn_roi_area = (rpn_rois[:, 2] - rpn_rois[:, 0]) * \
1329
(rpn_rois[:, 3] - rpn_rois[:, 1])
1330
gt_box_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * \
1331
(gt_boxes[:, 3] - gt_boxes[:, 1])
1332
1333
# Compute overlaps [rpn_rois, gt_boxes]
1334
overlaps = np.zeros((rpn_rois.shape[0], gt_boxes.shape[0]))
1335
for i in range(overlaps.shape[1]):
1336
gt = gt_boxes[i]
1337
overlaps[:, i] = utils.compute_iou(
1338
gt, rpn_rois, gt_box_area[i], rpn_roi_area)
1339
1340
# Assign ROIs to GT boxes
1341
rpn_roi_iou_argmax = np.argmax(overlaps, axis=1)
1342
rpn_roi_iou_max = overlaps[np.arange(
1343
overlaps.shape[0]), rpn_roi_iou_argmax]
1344
# GT box assigned to each ROI
1345
rpn_roi_gt_boxes = gt_boxes[rpn_roi_iou_argmax]
1346
rpn_roi_gt_class_ids = gt_class_ids[rpn_roi_iou_argmax]
1347
1348
# Positive ROIs are those with >= 0.5 IoU with a GT box.
1349
fg_ids = np.where(rpn_roi_iou_max > 0.5)[0]
1350
1351
# Negative ROIs are those with max IoU 0.1-0.5 (hard example mining)
1352
# TODO: To hard example mine or not to hard example mine, that's the question
1353
# bg_ids = np.where((rpn_roi_iou_max >= 0.1) & (rpn_roi_iou_max < 0.5))[0]
1354
bg_ids = np.where(rpn_roi_iou_max < 0.5)[0]
1355
1356
# Subsample ROIs. Aim for 33% foreground.
1357
# FG
1358
fg_roi_count = int(config.TRAIN_ROIS_PER_IMAGE * config.ROI_POSITIVE_RATIO)
1359
if fg_ids.shape[0] > fg_roi_count:
1360
keep_fg_ids = np.random.choice(fg_ids, fg_roi_count, replace=False)
1361
else:
1362
keep_fg_ids = fg_ids
1363
# BG
1364
remaining = config.TRAIN_ROIS_PER_IMAGE - keep_fg_ids.shape[0]
1365
if bg_ids.shape[0] > remaining:
1366
keep_bg_ids = np.random.choice(bg_ids, remaining, replace=False)
1367
else:
1368
keep_bg_ids = bg_ids
1369
# Combine indices of ROIs to keep
1370
keep = np.concatenate([keep_fg_ids, keep_bg_ids])
1371
# Need more?
1372
remaining = config.TRAIN_ROIS_PER_IMAGE - keep.shape[0]
1373
if remaining > 0:
1374
# Looks like we don't have enough samples to maintain the desired
1375
# balance. Reduce requirements and fill in the rest. This is
1376
# likely different from the Mask RCNN paper.
1377
1378
# There is a small chance we have neither fg nor bg samples.
1379
if keep.shape[0] == 0:
1380
# Pick bg regions with easier IoU threshold
1381
bg_ids = np.where(rpn_roi_iou_max < 0.5)[0]
1382
assert bg_ids.shape[0] >= remaining
1383
keep_bg_ids = np.random.choice(bg_ids, remaining, replace=False)
1384
assert keep_bg_ids.shape[0] == remaining
1385
keep = np.concatenate([keep, keep_bg_ids])
1386
else:
1387
# Fill the rest with repeated bg rois.
1388
keep_extra_ids = np.random.choice(
1389
keep_bg_ids, remaining, replace=True)
1390
keep = np.concatenate([keep, keep_extra_ids])
1391
assert keep.shape[0] == config.TRAIN_ROIS_PER_IMAGE, \
1392
"keep doesn't match ROI batch size {}, {}".format(
1393
keep.shape[0], config.TRAIN_ROIS_PER_IMAGE)
1394
1395
# Reset the gt boxes assigned to BG ROIs.
1396
rpn_roi_gt_boxes[keep_bg_ids, :] = 0
1397
rpn_roi_gt_class_ids[keep_bg_ids] = 0
1398
1399
# For each kept ROI, assign a class_id, and for FG ROIs also add bbox refinement.
1400
rois = rpn_rois[keep]
1401
roi_gt_boxes = rpn_roi_gt_boxes[keep]
1402
roi_gt_class_ids = rpn_roi_gt_class_ids[keep]
1403
roi_gt_assignment = rpn_roi_iou_argmax[keep]
1404
1405
# Class-aware bbox deltas. [y, x, log(h), log(w)]
1406
bboxes = np.zeros((config.TRAIN_ROIS_PER_IMAGE,
1407
config.NUM_CLASSES, 4), dtype=np.float32)
1408
pos_ids = np.where(roi_gt_class_ids > 0)[0]
1409
bboxes[pos_ids, roi_gt_class_ids[pos_ids]] = utils.box_refinement(
1410
rois[pos_ids], roi_gt_boxes[pos_ids, :4])
1411
# Normalize bbox refinements
1412
bboxes /= config.BBOX_STD_DEV
1413
1414
# Generate class-specific target masks
1415
masks = np.zeros((config.TRAIN_ROIS_PER_IMAGE, config.MASK_SHAPE[0], config.MASK_SHAPE[1], config.NUM_CLASSES),
1416
dtype=np.float32)
1417
for i in pos_ids:
1418
class_id = roi_gt_class_ids[i]
1419
assert class_id > 0, "class id must be greater than 0"
1420
gt_id = roi_gt_assignment[i]
1421
class_mask = gt_masks[:, :, gt_id]
1422
1423
if config.USE_MINI_MASK:
1424
# Create a mask placeholder, the size of the image
1425
placeholder = np.zeros(config.IMAGE_SHAPE[:2], dtype=bool)
1426
# GT box
1427
gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[gt_id]
1428
gt_w = gt_x2 - gt_x1
1429
gt_h = gt_y2 - gt_y1
1430
# Resize mini mask to size of GT box
1431
placeholder[gt_y1:gt_y2, gt_x1:gt_x2] = \
1432
np.round(utils.resize(class_mask, (gt_h, gt_w))).astype(bool)
1433
# Place the mini mask in the placeholder
1434
class_mask = placeholder
1435
1436
# Pick part of the mask and resize it
1437
y1, x1, y2, x2 = rois[i].astype(np.int32)
1438
m = class_mask[y1:y2, x1:x2]
1439
mask = utils.resize(m, config.MASK_SHAPE)
1440
masks[i, :, :, class_id] = mask
1441
1442
return rois, roi_gt_class_ids, bboxes, masks
1443
1444
1445
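# --- Hedged example (not part of the original implementation) ---
# A minimal sketch of how build_detection_targets() might be exercised when
# debugging the Stage 2 heads without the RPN. `dataset`, `config`, and
# `image_id` are assumed to exist; generate_random_rois() is defined further
# below in this module. This helper is illustrative only and never called.
def _example_detection_targets(dataset, config, image_id):
    """Sketch: build classifier/mask head targets for one image."""
    image, image_meta, gt_class_ids, gt_boxes, gt_masks = load_image_gt(
        dataset, config, image_id, use_mini_mask=config.USE_MINI_MASK)
    # Random proposals stand in for RPN output (see generate_random_rois below)
    rpn_rois = generate_random_rois(image.shape, 256, gt_class_ids, gt_boxes)
    rois, class_ids, bboxes, masks = build_detection_targets(
        rpn_rois, gt_class_ids, gt_boxes, gt_masks, config)
    return rois, class_ids, bboxes, masks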
def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
1446
"""Given the anchors and GT boxes, compute overlaps and identify positive
1447
anchors and deltas to refine them to match their corresponding GT boxes.
1448
1449
anchors: [num_anchors, (y1, x1, y2, x2)]
1450
gt_class_ids: [num_gt_boxes] Integer class IDs.
1451
gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]
1452
1453
Returns:
1454
rpn_match: [N] (int32) matches between anchors and GT boxes.
1455
1 = positive anchor, -1 = negative anchor, 0 = neutral
1456
rpn_bbox: [N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
1457
"""
1458
# RPN Match: 1 = positive anchor, -1 = negative anchor, 0 = neutral
1459
rpn_match = np.zeros([anchors.shape[0]], dtype=np.int32)
1460
# RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
1461
rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))
1462
1463
# Handle COCO crowds
1464
# A crowd box in COCO is a bounding box around several instances. Exclude
1465
# them from training. A crowd box is given a negative class ID.
1466
crowd_ix = np.where(gt_class_ids < 0)[0]
1467
if crowd_ix.shape[0] > 0:
1468
# Filter out crowds from ground truth class IDs and boxes
1469
non_crowd_ix = np.where(gt_class_ids > 0)[0]
1470
crowd_boxes = gt_boxes[crowd_ix]
1471
gt_class_ids = gt_class_ids[non_crowd_ix]
1472
gt_boxes = gt_boxes[non_crowd_ix]
1473
# Compute overlaps with crowd boxes [anchors, crowds]
1474
crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
1475
crowd_iou_max = np.amax(crowd_overlaps, axis=1)
1476
no_crowd_bool = (crowd_iou_max < 0.001)
1477
else:
1478
# No anchors intersect a crowd
1479
no_crowd_bool = np.ones([anchors.shape[0]], dtype=bool)
1480
1481
# Compute overlaps [num_anchors, num_gt_boxes]
1482
overlaps = utils.compute_overlaps(anchors, gt_boxes)
1483
1484
# Match anchors to GT Boxes
1485
# If an anchor overlaps a GT box with IoU >= 0.7 then it's positive.
1486
# If an anchor overlaps a GT box with IoU < 0.3 then it's negative.
1487
# Neutral anchors are those that don't match the conditions above,
1488
# and they don't influence the loss function.
1489
# However, don't keep any GT box unmatched (rare, but happens). Instead,
1490
# match it to the closest anchor (even if its max IoU is < 0.3).
1491
#
1492
# 1. Set negative anchors first. They get overwritten below if a GT box is
1493
# matched to them. Skip boxes in crowd areas.
1494
anchor_iou_argmax = np.argmax(overlaps, axis=1)
1495
anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
1496
rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1
1497
# 2. Set an anchor for each GT box (regardless of IoU value).
1498
# If multiple anchors have the same IoU, match all of them
1499
gt_iou_argmax = np.argwhere(overlaps == np.max(overlaps, axis=0))[:,0]
1500
rpn_match[gt_iou_argmax] = 1
1501
# 3. Set anchors with high overlap as positive.
1502
rpn_match[anchor_iou_max >= 0.7] = 1
1503
1504
# Subsample to balance positive and negative anchors
1505
# Don't let positives be more than half the anchors
1506
ids = np.where(rpn_match == 1)[0]
1507
extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE // 2)
1508
if extra > 0:
1509
# Reset the extra ones to neutral
1510
ids = np.random.choice(ids, extra, replace=False)
1511
rpn_match[ids] = 0
1512
# Same for negative proposals
1513
ids = np.where(rpn_match == -1)[0]
1514
extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE -
1515
np.sum(rpn_match == 1))
1516
if extra > 0:
1517
# Reset the extra ones to neutral
1518
ids = np.random.choice(ids, extra, replace=False)
1519
rpn_match[ids] = 0
1520
1521
# For positive anchors, compute shift and scale needed to transform them
1522
# to match the corresponding GT boxes.
1523
ids = np.where(rpn_match == 1)[0]
1524
ix = 0 # index into rpn_bbox
1525
# TODO: use box_refinement() rather than duplicating the code here
1526
for i, a in zip(ids, anchors[ids]):
1527
# Closest gt box (it might have IoU < 0.7)
1528
gt = gt_boxes[anchor_iou_argmax[i]]
1529
1530
# Convert coordinates to center plus width/height.
1531
# GT Box
1532
gt_h = gt[2] - gt[0]
1533
gt_w = gt[3] - gt[1]
1534
gt_center_y = gt[0] + 0.5 * gt_h
1535
gt_center_x = gt[1] + 0.5 * gt_w
1536
# Anchor
1537
a_h = a[2] - a[0]
1538
a_w = a[3] - a[1]
1539
a_center_y = a[0] + 0.5 * a_h
1540
a_center_x = a[1] + 0.5 * a_w
1541
1542
# Compute the bbox refinement that the RPN should predict.
1543
rpn_bbox[ix] = [
1544
(gt_center_y - a_center_y) / a_h,
1545
(gt_center_x - a_center_x) / a_w,
1546
np.log(gt_h / a_h),
1547
np.log(gt_w / a_w),
1548
]
1549
# Normalize
1550
rpn_bbox[ix] /= config.RPN_BBOX_STD_DEV
1551
ix += 1
1552
1553
return rpn_match, rpn_bbox
1554
1555
1556
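# --- Hedged example (illustrative only) ---
# A sketch of computing RPN targets for a single image, mirroring what
# data_generator() does per sample. `dataset`, `config`, and `image_id` are
# assumed names; nothing in the library calls this helper.
def _example_rpn_targets(dataset, config, image_id):
    """Sketch: compute rpn_match/rpn_bbox for one image."""
    image, image_meta, gt_class_ids, gt_boxes, gt_masks = load_image_gt(
        dataset, config, image_id, use_mini_mask=config.USE_MINI_MASK)
    backbone_shapes = compute_backbone_shapes(config, config.IMAGE_SHAPE)
    anchors = utils.generate_pyramid_anchors(
        config.RPN_ANCHOR_SCALES, config.RPN_ANCHOR_RATIOS,
        backbone_shapes, config.BACKBONE_STRIDES, config.RPN_ANCHOR_STRIDE)
    rpn_match, rpn_bbox = build_rpn_targets(
        image.shape, anchors, gt_class_ids, gt_boxes, config)
    return rpn_match, rpn_bbox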
def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes):
1557
"""Generates ROI proposals similar to what a region proposal network
1558
would generate.
1559
1560
image_shape: [Height, Width, Depth]
1561
count: Number of ROIs to generate
1562
gt_class_ids: [N] Integer ground truth class IDs
1563
gt_boxes: [N, (y1, x1, y2, x2)] Ground truth boxes in pixels.
1564
1565
Returns: [count, (y1, x1, y2, x2)] ROI boxes in pixels.
1566
"""
1567
# placeholder
1568
rois = np.zeros((count, 4), dtype=np.int32)
1569
1570
# Generate random ROIs around GT boxes (90% of count)
1571
rois_per_box = int(0.9 * count / gt_boxes.shape[0])
1572
for i in range(gt_boxes.shape[0]):
1573
gt_y1, gt_x1, gt_y2, gt_x2 = gt_boxes[i]
1574
h = gt_y2 - gt_y1
1575
w = gt_x2 - gt_x1
1576
# random boundaries
1577
r_y1 = max(gt_y1 - h, 0)
1578
r_y2 = min(gt_y2 + h, image_shape[0])
1579
r_x1 = max(gt_x1 - w, 0)
1580
r_x2 = min(gt_x2 + w, image_shape[1])
1581
1582
# To avoid generating boxes with zero area, we generate double what
1583
# we need and filter out the extra. If we get fewer valid boxes
1584
# than we need, we loop and try again.
1585
while True:
1586
y1y2 = np.random.randint(r_y1, r_y2, (rois_per_box * 2, 2))
1587
x1x2 = np.random.randint(r_x1, r_x2, (rois_per_box * 2, 2))
1588
# Filter out zero area boxes
1589
threshold = 1
1590
y1y2 = y1y2[np.abs(y1y2[:, 0] - y1y2[:, 1]) >=
1591
threshold][:rois_per_box]
1592
x1x2 = x1x2[np.abs(x1x2[:, 0] - x1x2[:, 1]) >=
1593
threshold][:rois_per_box]
1594
if y1y2.shape[0] == rois_per_box and x1x2.shape[0] == rois_per_box:
1595
break
1596
1597
# Sort on axis 1 to ensure x1 <= x2 and y1 <= y2 and then reshape
1598
# into x1, y1, x2, y2 order
1599
x1, x2 = np.split(np.sort(x1x2, axis=1), 2, axis=1)
1600
y1, y2 = np.split(np.sort(y1y2, axis=1), 2, axis=1)
1601
box_rois = np.hstack([y1, x1, y2, x2])
1602
rois[rois_per_box * i:rois_per_box * (i + 1)] = box_rois
1603
1604
# Generate random ROIs anywhere in the image (10% of count)
1605
remaining_count = count - (rois_per_box * gt_boxes.shape[0])
1606
# To avoid generating boxes with zero area, we generate double what
1607
# we need and filter out the extra. If we get fewer valid boxes
1608
# than we need, we loop and try again.
1609
while True:
1610
y1y2 = np.random.randint(0, image_shape[0], (remaining_count * 2, 2))
1611
x1x2 = np.random.randint(0, image_shape[1], (remaining_count * 2, 2))
1612
# Filter out zero area boxes
1613
threshold = 1
1614
y1y2 = y1y2[np.abs(y1y2[:, 0] - y1y2[:, 1]) >=
1615
threshold][:remaining_count]
1616
x1x2 = x1x2[np.abs(x1x2[:, 0] - x1x2[:, 1]) >=
1617
threshold][:remaining_count]
1618
if y1y2.shape[0] == remaining_count and x1x2.shape[0] == remaining_count:
1619
break
1620
1621
# Sort on axis 1 to ensure x1 <= x2 and y1 <= y2 and then reshape
1622
# into x1, y1, x2, y2 order
1623
x1, x2 = np.split(np.sort(x1x2, axis=1), 2, axis=1)
1624
y1, y2 = np.split(np.sort(y1y2, axis=1), 2, axis=1)
1625
global_rois = np.hstack([y1, x1, y2, x2])
1626
rois[-remaining_count:] = global_rois
1627
return rois
1628
1629
1630
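# Hedged usage note for generate_random_rois(): given ground truth boxes in
# pixel coordinates, it produces ROI proposals without running the RPN, which
# pairs with build_detection_targets() above. `image`, `gt_class_ids`, and
# `gt_boxes` are assumed to come from load_image_gt(). For example:
#
#   rpn_rois = generate_random_rois(image.shape, 1000, gt_class_ids, gt_boxes)
#   # rpn_rois.shape == (1000, 4), dtype int32, boxes as (y1, x1, y2, x2)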
def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None,
1631
random_rois=0, batch_size=1, detection_targets=False,
1632
no_augmentation_sources=None):
1633
"""A generator that returns images and corresponding target class ids,
1634
bounding box deltas, and masks.
1635
1636
dataset: The Dataset object to pick data from
1637
config: The model config object
1638
shuffle: If True, shuffles the samples before every epoch
1639
augment: (deprecated. Use augmentation instead). If true, apply random
1640
image augmentation. Currently, only horizontal flipping is offered.
1641
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
1642
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
1643
right/left 50% of the time.
1644
random_rois: If > 0 then generate proposals to be used to train the
1645
network classifier and mask heads. Useful if training
1646
the Mask RCNN part without the RPN.
1647
batch_size: How many images to return in each call
1648
detection_targets: If True, generate detection targets (class IDs, bbox
1649
deltas, and masks). Typically for debugging or visualizations because
1650
in training, detection targets are generated by DetectionTargetLayer.
1651
no_augmentation_sources: Optional. List of sources to exclude for
1652
augmentation. A source is a string that identifies a dataset and is
1653
defined in the Dataset class.
1654
1655
Returns a Python generator. Upon calling next() on it, the
1656
generator returns two lists, inputs and outputs. The contents
1657
of the lists differ depending on the received arguments:
1658
inputs list:
1659
- images: [batch, H, W, C]
1660
- image_meta: [batch, (meta data)] Image details. See compose_image_meta()
1661
- rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral)
1662
- rpn_bbox: [batch, N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
1663
- gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs
1664
- gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)]
1665
- gt_masks: [batch, height, width, MAX_GT_INSTANCES]. The height and width
1666
are those of the image unless use_mini_mask is True, in which
1667
case they are defined in MINI_MASK_SHAPE.
1668
1669
outputs list: Usually empty in regular training. But if detection_targets
1670
is True then the outputs list contains target class_ids, bbox deltas,
1671
and masks.
1672
"""
1673
b = 0 # batch item index
1674
image_index = -1
1675
image_ids = np.copy(dataset.image_ids)
1676
error_count = 0
1677
no_augmentation_sources = no_augmentation_sources or []
1678
1679
# Anchors
1680
# [anchor_count, (y1, x1, y2, x2)]
1681
backbone_shapes = compute_backbone_shapes(config, config.IMAGE_SHAPE)
1682
anchors = utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES,
1683
config.RPN_ANCHOR_RATIOS,
1684
backbone_shapes,
1685
config.BACKBONE_STRIDES,
1686
config.RPN_ANCHOR_STRIDE)
1687
1688
# Keras requires a generator to run indefinitely.
1689
while True:
1690
try:
1691
# Increment index to pick next image. Shuffle if at the start of an epoch.
1692
image_index = (image_index + 1) % len(image_ids)
1693
if shuffle and image_index == 0:
1694
np.random.shuffle(image_ids)
1695
1696
# Get GT bounding boxes and masks for image.
1697
image_id = image_ids[image_index]
1698
1699
# If the image source is not to be augmented, pass None as augmentation
1700
if dataset.image_info[image_id]['source'] in no_augmentation_sources:
1701
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
1702
load_image_gt(dataset, config, image_id, augment=augment,
1703
augmentation=None,
1704
use_mini_mask=config.USE_MINI_MASK)
1705
else:
1706
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
1707
load_image_gt(dataset, config, image_id, augment=augment,
1708
augmentation=augmentation,
1709
use_mini_mask=config.USE_MINI_MASK)
1710
1711
# Skip images that have no instances. This can happen in cases
1712
# where we train on a subset of classes and the image doesn't
1713
# have any of the classes we care about.
1714
if not np.any(gt_class_ids > 0):
1715
continue
1716
1717
# RPN Targets
1718
rpn_match, rpn_bbox = build_rpn_targets(image.shape, anchors,
1719
gt_class_ids, gt_boxes, config)
1720
1721
# Mask R-CNN Targets
1722
if random_rois:
1723
rpn_rois = generate_random_rois(
1724
image.shape, random_rois, gt_class_ids, gt_boxes)
1725
if detection_targets:
1726
rois, mrcnn_class_ids, mrcnn_bbox, mrcnn_mask =\
1727
build_detection_targets(
1728
rpn_rois, gt_class_ids, gt_boxes, gt_masks, config)
1729
1730
# Init batch arrays
1731
if b == 0:
1732
batch_image_meta = np.zeros(
1733
(batch_size,) + image_meta.shape, dtype=image_meta.dtype)
1734
batch_rpn_match = np.zeros(
1735
[batch_size, anchors.shape[0], 1], dtype=rpn_match.dtype)
1736
batch_rpn_bbox = np.zeros(
1737
[batch_size, config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=rpn_bbox.dtype)
1738
batch_images = np.zeros(
1739
(batch_size,) + image.shape, dtype=np.float32)
1740
batch_gt_class_ids = np.zeros(
1741
(batch_size, config.MAX_GT_INSTANCES), dtype=np.int32)
1742
batch_gt_boxes = np.zeros(
1743
(batch_size, config.MAX_GT_INSTANCES, 4), dtype=np.int32)
1744
batch_gt_masks = np.zeros(
1745
(batch_size, gt_masks.shape[0], gt_masks.shape[1],
1746
config.MAX_GT_INSTANCES), dtype=gt_masks.dtype)
1747
if random_rois:
1748
batch_rpn_rois = np.zeros(
1749
(batch_size, rpn_rois.shape[0], 4), dtype=rpn_rois.dtype)
1750
if detection_targets:
1751
batch_rois = np.zeros(
1752
(batch_size,) + rois.shape, dtype=rois.dtype)
1753
batch_mrcnn_class_ids = np.zeros(
1754
(batch_size,) + mrcnn_class_ids.shape, dtype=mrcnn_class_ids.dtype)
1755
batch_mrcnn_bbox = np.zeros(
1756
(batch_size,) + mrcnn_bbox.shape, dtype=mrcnn_bbox.dtype)
1757
batch_mrcnn_mask = np.zeros(
1758
(batch_size,) + mrcnn_mask.shape, dtype=mrcnn_mask.dtype)
1759
1760
# If there are more instances than fit in the array, sub-sample from them.
1761
if gt_boxes.shape[0] > config.MAX_GT_INSTANCES:
1762
ids = np.random.choice(
1763
np.arange(gt_boxes.shape[0]), config.MAX_GT_INSTANCES, replace=False)
1764
gt_class_ids = gt_class_ids[ids]
1765
gt_boxes = gt_boxes[ids]
1766
gt_masks = gt_masks[:, :, ids]
1767
1768
# Add to batch
1769
batch_image_meta[b] = image_meta
1770
batch_rpn_match[b] = rpn_match[:, np.newaxis]
1771
batch_rpn_bbox[b] = rpn_bbox
1772
batch_images[b] = mold_image(image.astype(np.float32), config)
1773
batch_gt_class_ids[b, :gt_class_ids.shape[0]] = gt_class_ids
1774
batch_gt_boxes[b, :gt_boxes.shape[0]] = gt_boxes
1775
batch_gt_masks[b, :, :, :gt_masks.shape[-1]] = gt_masks
1776
if random_rois:
1777
batch_rpn_rois[b] = rpn_rois
1778
if detection_targets:
1779
batch_rois[b] = rois
1780
batch_mrcnn_class_ids[b] = mrcnn_class_ids
1781
batch_mrcnn_bbox[b] = mrcnn_bbox
1782
batch_mrcnn_mask[b] = mrcnn_mask
1783
b += 1
1784
1785
# Batch full?
1786
if b >= batch_size:
1787
inputs = [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox,
1788
batch_gt_class_ids, batch_gt_boxes, batch_gt_masks]
1789
outputs = []
1790
1791
if random_rois:
1792
inputs.extend([batch_rpn_rois])
1793
if detection_targets:
1794
inputs.extend([batch_rois])
1795
# Keras requires that output and targets have the same number of dimensions
1796
batch_mrcnn_class_ids = np.expand_dims(
1797
batch_mrcnn_class_ids, -1)
1798
outputs.extend(
1799
[batch_mrcnn_class_ids, batch_mrcnn_bbox, batch_mrcnn_mask])
1800
1801
yield inputs, outputs
1802
1803
# start a new batch
1804
b = 0
1805
except (GeneratorExit, KeyboardInterrupt):
1806
raise
1807
except:
1808
# Log it and skip the image
1809
logging.exception("Error processing image {}".format(
1810
dataset.image_info[image_id]))
1811
error_count += 1
1812
if error_count > 5:
1813
raise
1814
1815
1816
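# Hedged usage sketch for data_generator() (`train_dataset` and `config` are
# assumed to exist): pull one batch and inspect the RPN targets.
#
#   gen = data_generator(train_dataset, config, shuffle=True,
#                        batch_size=config.BATCH_SIZE)
#   inputs, outputs = next(gen)
#   images, image_meta, rpn_match, rpn_bbox, \
#       gt_class_ids, gt_boxes, gt_masks = inputs
#   # rpn_match: [batch, anchors, 1]
#   # rpn_bbox:  [batch, RPN_TRAIN_ANCHORS_PER_IMAGE, 4]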
############################################################
1817
# MaskRCNN Class
1818
############################################################
1819
1820
class MaskRCNN():
1821
"""Encapsulates the Mask RCNN model functionality.
1822
1823
The actual Keras model is in the keras_model property.
1824
"""
1825
1826
def __init__(self, mode, config, model_dir):
1827
"""
1828
mode: Either "training" or "inference"
1829
config: A Sub-class of the Config class
1830
model_dir: Directory to save training logs and trained weights
1831
"""
1832
assert mode in ['training', 'inference']
1833
self.mode = mode
1834
self.config = config
1835
self.model_dir = model_dir
1836
self.set_log_dir()
1837
self.keras_model = self.build(mode=mode, config=config)
1838
1839
def build(self, mode, config):
1840
"""Build Mask R-CNN architecture.
1841
input_shape: The shape of the input image.
1842
mode: Either "training" or "inference". The inputs and
1843
outputs of the model differ accordingly.
1844
"""
1845
assert mode in ['training', 'inference']
1846
1847
# Image size must be divisible by 2 multiple times
1848
h, w = config.IMAGE_SHAPE[:2]
1849
if h / 2**6 != int(h / 2**6) or w / 2**6 != int(w / 2**6):
1850
raise Exception("Image size must be dividable by 2 at least 6 times "
1851
"to avoid fractions when downscaling and upscaling."
1852
"For example, use 256, 320, 384, 448, 512, ... etc. ")
1853
1854
# Inputs
1855
input_image = KL.Input(
1856
shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image")
1857
input_image_meta = KL.Input(shape=[config.IMAGE_META_SIZE],
1858
name="input_image_meta")
1859
if mode == "training":
1860
# RPN GT
1861
input_rpn_match = KL.Input(
1862
shape=[None, 1], name="input_rpn_match", dtype=tf.int32)
1863
input_rpn_bbox = KL.Input(
1864
shape=[None, 4], name="input_rpn_bbox", dtype=tf.float32)
1865
1866
# Detection GT (class IDs, bounding boxes, and masks)
1867
# 1. GT Class IDs (zero padded)
1868
input_gt_class_ids = KL.Input(
1869
shape=[None], name="input_gt_class_ids", dtype=tf.int32)
1870
# 2. GT Boxes in pixels (zero padded)
1871
# [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in image coordinates
1872
input_gt_boxes = KL.Input(
1873
shape=[None, 4], name="input_gt_boxes", dtype=tf.float32)
1874
# Normalize coordinates
1875
gt_boxes = KL.Lambda(lambda x: norm_boxes_graph(
1876
x, K.shape(input_image)[1:3]))(input_gt_boxes)
1877
# 3. GT Masks (zero padded)
1878
# [batch, height, width, MAX_GT_INSTANCES]
1879
if config.USE_MINI_MASK:
1880
input_gt_masks = KL.Input(
1881
shape=[config.MINI_MASK_SHAPE[0],
1882
config.MINI_MASK_SHAPE[1], None],
1883
name="input_gt_masks", dtype=bool)
1884
else:
1885
input_gt_masks = KL.Input(
1886
shape=[config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1], None],
1887
name="input_gt_masks", dtype=bool)
1888
elif mode == "inference":
1889
# Anchors in normalized coordinates
1890
input_anchors = KL.Input(shape=[None, 4], name="input_anchors")
1891
1892
# Build the shared convolutional layers.
1893
# Bottom-up Layers
1894
# Returns a list of the last layers of each stage, 5 in total.
1895
# Don't create the head (stage 5), so we pick the 4th item in the list.
1896
if callable(config.BACKBONE):
1897
_, C2, C3, C4, C5 = config.BACKBONE(input_image, stage5=True,
1898
train_bn=config.TRAIN_BN)
1899
else:
1900
_, C2, C3, C4, C5 = resnet_graph(input_image, config.BACKBONE,
1901
stage5=True, train_bn=config.TRAIN_BN)
1902
# Top-down Layers
1903
# TODO: add assert to verify feature map sizes match what's in config
1904
P5 = KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c5p5')(C5)
1905
P4 = KL.Add(name="fpn_p4add")([
1906
KL.UpSampling2D(size=(2, 2), name="fpn_p5upsampled")(P5),
1907
KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c4p4')(C4)])
1908
P3 = KL.Add(name="fpn_p3add")([
1909
KL.UpSampling2D(size=(2, 2), name="fpn_p4upsampled")(P4),
1910
KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c3p3')(C3)])
1911
P2 = KL.Add(name="fpn_p2add")([
1912
KL.UpSampling2D(size=(2, 2), name="fpn_p3upsampled")(P3),
1913
KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (1, 1), name='fpn_c2p2')(C2)])
1914
# Attach 3x3 conv to all P layers to get the final feature maps.
1915
P2 = KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p2")(P2)
1916
P3 = KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p3")(P3)
1917
P4 = KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p4")(P4)
1918
P5 = KL.Conv2D(config.TOP_DOWN_PYRAMID_SIZE, (3, 3), padding="SAME", name="fpn_p5")(P5)
1919
# P6 is used for the 5th anchor scale in RPN. Generated by
1920
# subsampling from P5 with stride of 2.
1921
P6 = KL.MaxPooling2D(pool_size=(1, 1), strides=2, name="fpn_p6")(P5)
1922
1923
# Note that P6 is used in RPN, but not in the classifier heads.
1924
rpn_feature_maps = [P2, P3, P4, P5, P6]
1925
mrcnn_feature_maps = [P2, P3, P4, P5]
1926
1927
# Anchors
1928
if mode == "training":
1929
anchors = self.get_anchors(config.IMAGE_SHAPE)
1930
# Duplicate across the batch dimension because Keras requires it
1931
# TODO: can this be optimized to avoid duplicating the anchors?
1932
anchors = np.broadcast_to(anchors, (config.BATCH_SIZE,) + anchors.shape)
1933
# A hack to get around Keras's bad support for constants
1934
anchors = KL.Lambda(lambda x: tf.Variable(anchors), name="anchors")(input_image)
1935
else:
1936
anchors = input_anchors
1937
1938
# RPN Model
1939
rpn = build_rpn_model(config.RPN_ANCHOR_STRIDE,
1940
len(config.RPN_ANCHOR_RATIOS), config.TOP_DOWN_PYRAMID_SIZE)
1941
# Loop through pyramid layers
1942
layer_outputs = [] # list of lists
1943
for p in rpn_feature_maps:
1944
layer_outputs.append(rpn([p]))
1945
# Concatenate layer outputs
1946
# Convert from list of lists of level outputs to list of lists
1947
# of outputs across levels.
1948
# e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
1949
output_names = ["rpn_class_logits", "rpn_class", "rpn_bbox"]
1950
outputs = list(zip(*layer_outputs))
1951
outputs = [KL.Concatenate(axis=1, name=n)(list(o))
1952
for o, n in zip(outputs, output_names)]
1953
1954
rpn_class_logits, rpn_class, rpn_bbox = outputs
1955
1956
# Generate proposals
1957
# Proposals are [batch, N, (y1, x1, y2, x2)] in normalized coordinates
1958
# and zero padded.
1959
proposal_count = config.POST_NMS_ROIS_TRAINING if mode == "training"\
1960
else config.POST_NMS_ROIS_INFERENCE
1961
rpn_rois = ProposalLayer(
1962
proposal_count=proposal_count,
1963
nms_threshold=config.RPN_NMS_THRESHOLD,
1964
name="ROI",
1965
config=config)([rpn_class, rpn_bbox, anchors])
1966
1967
if mode == "training":
1968
# Class ID mask to mark class IDs supported by the dataset the image
1969
# came from.
1970
active_class_ids = KL.Lambda(
1971
lambda x: parse_image_meta_graph(x)["active_class_ids"]
1972
)(input_image_meta)
1973
1974
if not config.USE_RPN_ROIS:
1975
# Ignore predicted ROIs and use ROIs provided as an input.
1976
input_rois = KL.Input(shape=[config.POST_NMS_ROIS_TRAINING, 4],
1977
name="input_roi", dtype=np.int32)
1978
# Normalize coordinates
1979
target_rois = KL.Lambda(lambda x: norm_boxes_graph(
1980
x, K.shape(input_image)[1:3]))(input_rois)
1981
else:
1982
target_rois = rpn_rois
1983
1984
# Generate detection targets
1985
# Subsamples proposals and generates target outputs for training
1986
# Note that proposal class IDs, gt_boxes, and gt_masks are zero
1987
# padded. Equally, returned rois and targets are zero padded.
1988
rois, target_class_ids, target_bbox, target_mask =\
1989
DetectionTargetLayer(config, name="proposal_targets")([
1990
target_rois, input_gt_class_ids, gt_boxes, input_gt_masks])
1991
1992
# Network Heads
1993
# TODO: verify that this handles zero padded ROIs
1994
mrcnn_class_logits, mrcnn_class, mrcnn_bbox =\
1995
fpn_classifier_graph(rois, mrcnn_feature_maps, input_image_meta,
1996
config.POOL_SIZE, config.NUM_CLASSES,
1997
train_bn=config.TRAIN_BN,
1998
fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)
1999
2000
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
2001
input_image_meta,
2002
config.MASK_POOL_SIZE,
2003
config.NUM_CLASSES,
2004
train_bn=config.TRAIN_BN)
2005
2006
# TODO: clean up (use tf.identify if necessary)
2007
output_rois = KL.Lambda(lambda x: x * 1, name="output_rois")(rois)
2008
2009
# Losses
2010
rpn_class_loss = KL.Lambda(lambda x: rpn_class_loss_graph(*x), name="rpn_class_loss")(
2011
[input_rpn_match, rpn_class_logits])
2012
rpn_bbox_loss = KL.Lambda(lambda x: rpn_bbox_loss_graph(config, *x), name="rpn_bbox_loss")(
2013
[input_rpn_bbox, input_rpn_match, rpn_bbox])
2014
class_loss = KL.Lambda(lambda x: mrcnn_class_loss_graph(*x), name="mrcnn_class_loss")(
2015
[target_class_ids, mrcnn_class_logits, active_class_ids])
2016
bbox_loss = KL.Lambda(lambda x: mrcnn_bbox_loss_graph(*x), name="mrcnn_bbox_loss")(
2017
[target_bbox, target_class_ids, mrcnn_bbox])
2018
mask_loss = KL.Lambda(lambda x: mrcnn_mask_loss_graph(*x), name="mrcnn_mask_loss")(
2019
[target_mask, target_class_ids, mrcnn_mask])
2020
2021
# Model
2022
inputs = [input_image, input_image_meta,
2023
input_rpn_match, input_rpn_bbox, input_gt_class_ids, input_gt_boxes, input_gt_masks]
2024
if not config.USE_RPN_ROIS:
2025
inputs.append(input_rois)
2026
outputs = [rpn_class_logits, rpn_class, rpn_bbox,
2027
mrcnn_class_logits, mrcnn_class, mrcnn_bbox, mrcnn_mask,
2028
rpn_rois, output_rois,
2029
rpn_class_loss, rpn_bbox_loss, class_loss, bbox_loss, mask_loss]
2030
model = KM.Model(inputs, outputs, name='mask_rcnn')
2031
else:
2032
# Network Heads
2033
# Proposal classifier and BBox regressor heads
2034
mrcnn_class_logits, mrcnn_class, mrcnn_bbox =\
2035
fpn_classifier_graph(rpn_rois, mrcnn_feature_maps, input_image_meta,
2036
config.POOL_SIZE, config.NUM_CLASSES,
2037
train_bn=config.TRAIN_BN,
2038
fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)
2039
2040
# Detections
2041
# output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in
2042
# normalized coordinates
2043
detections = DetectionLayer(config, name="mrcnn_detection")(
2044
[rpn_rois, mrcnn_class, mrcnn_bbox, input_image_meta])
2045
2046
# Create masks for detections
2047
detection_boxes = KL.Lambda(lambda x: x[..., :4])(detections)
2048
mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps,
2049
input_image_meta,
2050
config.MASK_POOL_SIZE,
2051
config.NUM_CLASSES,
2052
train_bn=config.TRAIN_BN)
2053
2054
model = KM.Model([input_image, input_image_meta, input_anchors],
2055
[detections, mrcnn_class, mrcnn_bbox,
2056
mrcnn_mask, rpn_rois, rpn_class, rpn_bbox],
2057
name='mask_rcnn')
2058
2059
# Add multi-GPU support.
2060
if config.GPU_COUNT > 1:
2061
from mrcnn.parallel_model import ParallelModel
2062
model = ParallelModel(model, config.GPU_COUNT)
2063
2064
return model
2065
2066
def find_last(self):
2067
"""Finds the last checkpoint file of the last trained model in the
2068
model directory.
2069
Returns:
2070
The path of the last checkpoint file
2071
"""
2072
# Get directory names. Each directory corresponds to a model
2073
dir_names = next(os.walk(self.model_dir))[1]
2074
key = self.config.NAME.lower()
2075
dir_names = filter(lambda f: f.startswith(key), dir_names)
2076
dir_names = sorted(dir_names)
2077
if not dir_names:
2078
import errno
2079
raise FileNotFoundError(
2080
errno.ENOENT,
2081
"Could not find model directory under {}".format(self.model_dir))
2082
# Pick last directory
2083
dir_name = os.path.join(self.model_dir, dir_names[-1])
2084
# Find the last checkpoint
2085
checkpoints = next(os.walk(dir_name))[2]
2086
checkpoints = filter(lambda f: f.startswith("mask_rcnn"), checkpoints)
2087
checkpoints = sorted(checkpoints)
2088
if not checkpoints:
2089
import errno
2090
raise FileNotFoundError(
2091
errno.ENOENT, "Could not find weight files in {}".format(dir_name))
2092
checkpoint = os.path.join(dir_name, checkpoints[-1])
2093
return checkpoint
2094
2095
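    # Hedged usage note: find_last() is commonly paired with load_weights() to
    # resume training from the most recent checkpoint, e.g.:
    #   model.load_weights(model.find_last(), by_name=True)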
def load_weights(self, filepath, by_name=False, exclude=None):
2096
"""Modified version of the corresponding Keras function with
2097
the addition of multi-GPU support and the ability to exclude
2098
some layers from loading.
2099
exclude: list of layer names to exclude
2100
"""
2101
import h5py
2102
# Conditional import to support versions of Keras before 2.2
2103
# TODO: remove in about 6 months (end of 2018)
2104
try:
2105
from keras.engine import saving
2106
except ImportError:
2107
# Keras before 2.2 used the 'topology' namespace.
2108
from keras.engine import topology as saving
2109
2110
if exclude:
2111
by_name = True
2112
2113
if h5py is None:
2114
raise ImportError('`load_weights` requires h5py.')
2115
f = h5py.File(filepath, mode='r')
2116
if 'layer_names' not in f.attrs and 'model_weights' in f:
2117
f = f['model_weights']
2118
2119
# In multi-GPU training, we wrap the model. Get layers
2120
# of the inner model because they have the weights.
2121
keras_model = self.keras_model
2122
layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model")\
2123
else keras_model.layers
2124
2125
# Exclude some layers
2126
if exclude:
2127
layers = filter(lambda l: l.name not in exclude, layers)
2128
2129
if by_name:
2130
saving.load_weights_from_hdf5_group_by_name(f, layers)
2131
else:
2132
saving.load_weights_from_hdf5_group(f, layers)
2133
if hasattr(f, 'close'):
2134
f.close()
2135
2136
# Update the log directory
2137
self.set_log_dir(filepath)
2138
2139
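    # Hedged usage sketch (COCO_WEIGHTS_PATH is a hypothetical path; the layer
    # names follow this repo's head layers but should be treated as an
    # assumption here): when fine-tuning on a dataset with a different number
    # of classes, the class-specific heads are usually excluded so their
    # weight shapes are free to differ:
    #   model.load_weights(COCO_WEIGHTS_PATH, by_name=True,
    #                      exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
    #                               "mrcnn_bbox", "mrcnn_mask"])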
def get_imagenet_weights(self):
2140
"""Downloads ImageNet trained weights from Keras.
2141
Returns path to weights file.
2142
"""
2143
from keras.utils.data_utils import get_file
2144
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/'\
2145
'releases/download/v0.2/'\
2146
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
2147
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
2148
TF_WEIGHTS_PATH_NO_TOP,
2149
cache_subdir='models',
2150
md5_hash='a268eb855778b3df3c7506639542a6af')
2151
return weights_path
2152
2153
def compile(self, learning_rate, momentum):
2154
"""Gets the model ready for training. Adds losses, regularization, and
2155
metrics. Then calls the Keras compile() function.
2156
"""
2157
# Optimizer object
2158
optimizer = keras.optimizers.SGD(
2159
lr=learning_rate, momentum=momentum,
2160
clipnorm=self.config.GRADIENT_CLIP_NORM)
2161
# Add Losses
2162
# First, clear previously set losses to avoid duplication
2163
self.keras_model._losses = []
2164
self.keras_model._per_input_losses = {}
2165
loss_names = [
2166
"rpn_class_loss", "rpn_bbox_loss",
2167
"mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"]
2168
for name in loss_names:
2169
layer = self.keras_model.get_layer(name)
2170
if layer.output in self.keras_model.losses:
2171
continue
2172
loss = (
2173
tf.reduce_mean(layer.output, keepdims=True)
2174
* self.config.LOSS_WEIGHTS.get(name, 1.))
2175
self.keras_model.add_loss(loss)
2176
2177
# Add L2 Regularization
2178
# Skip gamma and beta weights of batch normalization layers.
2179
reg_losses = [
2180
keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
2181
for w in self.keras_model.trainable_weights
2182
if 'gamma' not in w.name and 'beta' not in w.name]
2183
self.keras_model.add_loss(tf.add_n(reg_losses))
2184
2185
# Compile
2186
self.keras_model.compile(
2187
optimizer=optimizer,
2188
loss=[None] * len(self.keras_model.outputs))
2189
2190
# Add metrics for losses
2191
for name in loss_names:
2192
if name in self.keras_model.metrics_names:
2193
continue
2194
layer = self.keras_model.get_layer(name)
2195
self.keras_model.metrics_names.append(name)
2196
loss = (
2197
tf.reduce_mean(layer.output, keepdims=True)
2198
* self.config.LOSS_WEIGHTS.get(name, 1.))
2199
self.keras_model.metrics_tensors.append(loss)
2200
2201
def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
2202
"""Sets model layers as trainable if their names match
2203
the given regular expression.
2204
"""
2205
# Print message on the first call (but not on recursive calls)
2206
if verbose > 0 and keras_model is None:
2207
log("Selecting layers to train")
2208
2209
keras_model = keras_model or self.keras_model
2210
2211
# In multi-GPU training, we wrap the model. Get layers
2212
# of the inner model because they have the weights.
2213
layers = keras_model.inner_model.layers if hasattr(keras_model, "inner_model")\
2214
else keras_model.layers
2215
2216
for layer in layers:
2217
# Is the layer a model?
2218
if layer.__class__.__name__ == 'Model':
2219
print("In model: ", layer.name)
2220
self.set_trainable(
2221
layer_regex, keras_model=layer, indent=indent + 4)
2222
continue
2223
2224
if not layer.weights:
2225
continue
2226
# Is it trainable?
2227
trainable = bool(re.fullmatch(layer_regex, layer.name))
2228
# Update layer. If layer is a container, update inner layer.
2229
if layer.__class__.__name__ == 'TimeDistributed':
2230
layer.layer.trainable = trainable
2231
else:
2232
layer.trainable = trainable
2233
# Print trainable layer names
2234
if trainable and verbose > 0:
2235
log("{}{:20} ({})".format(" " * indent, layer.name,
2236
layer.__class__.__name__))
2237
2238
def set_log_dir(self, model_path=None):
2239
"""Sets the model log directory and epoch counter.
2240
2241
model_path: If None, or a format different from what this code uses,
2242
then set a new log directory and start epochs from 0. Otherwise,
2243
extract the log directory and the epoch counter from the file
2244
name.
2245
"""
2246
# Set date and epoch counter as if starting a new model
2247
self.epoch = 0
2248
now = datetime.datetime.now()
2249
2250
# If we have a model path with date and epochs use them
2251
if model_path:
2252
# Continue from where we left off. Get epoch and date from the file name
2253
# A sample model path might look like:
2254
# \path\to\logs\coco20171029T2315\mask_rcnn_coco_0001.h5 (Windows)
2255
# /path/to/logs/coco20171029T2315/mask_rcnn_coco_0001.h5 (Linux)
2256
regex = r".*[/\\][\w-]+(\d{4})(\d{2})(\d{2})T(\d{2})(\d{2})[/\\]mask\_rcnn\_[\w-]+(\d{4})\.h5"
2257
m = re.match(regex, model_path)
2258
if m:
2259
now = datetime.datetime(int(m.group(1)), int(m.group(2)), int(m.group(3)),
2260
int(m.group(4)), int(m.group(5)))
2261
# Epoch number in file is 1-based, and in Keras code it's 0-based.
2262
# So, adjust for that then increment by one to start from the next epoch
2263
self.epoch = int(m.group(6)) - 1 + 1
2264
print('Re-starting from epoch %d' % self.epoch)
2265
2266
# Directory for training logs
2267
self.log_dir = os.path.join(self.model_dir, "{}{:%Y%m%dT%H%M}".format(
2268
self.config.NAME.lower(), now))
2269
2270
# Path to save after each epoch. Include placeholders that get filled by Keras.
2271
self.checkpoint_path = os.path.join(self.log_dir, "mask_rcnn_{}_*epoch*.h5".format(
2272
self.config.NAME.lower()))
2273
self.checkpoint_path = self.checkpoint_path.replace(
2274
"*epoch*", "{epoch:04d}")
2275
2276
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
2277
augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
2278
"""Train the model.
2279
train_dataset, val_dataset: Training and validation Dataset objects.
2280
learning_rate: The learning rate to train with
2281
epochs: Number of training epochs. Note that previous training epochs
2282
are considered to be done already, so this actually determines
2283
the epochs to train in total rather than in this particular
2284
call.
2285
layers: Allows selecting which layers to train. It can be:
2286
- A regular expression to match layer names to train
2287
- One of these predefined values:
2288
heads: The RPN, classifier and mask heads of the network
2289
all: All the layers
2290
3+: Train Resnet stage 3 and up
2291
4+: Train Resnet stage 4 and up
2292
5+: Train Resnet stage 5 and up
2293
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug)
2294
augmentation. For example, passing imgaug.augmenters.Fliplr(0.5)
2295
flips images right/left 50% of the time. You can pass complex
2296
augmentations as well. For example, the augmentation below applies 50% of the
2297
time, and when it does it flips images right/left half the time
2298
and adds a Gaussian blur with a random sigma in range 0 to 5.
2299
2300
augmentation = imgaug.augmenters.Sometimes(0.5, [
2301
imgaug.augmenters.Fliplr(0.5),
2302
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
2303
])
2304
custom_callbacks: Optional. Add custom callbacks to be called
2305
with the keras fit_generator method. Must be a list of type keras.callbacks.
2306
no_augmentation_sources: Optional. List of sources to exclude for
2307
augmentation. A source is a string that identifies a dataset and is
2308
defined in the Dataset class.
2309
"""
2310
assert self.mode == "training", "Create model in training mode."
2311
2312
# Pre-defined layer regular expressions
2313
layer_regex = {
2314
# all layers but the backbone
2315
"heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2316
# From a specific Resnet stage and up
2317
"3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2318
"4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2319
"5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2320
# All layers
2321
"all": ".*",
2322
}
2323
if layers in layer_regex.keys():
2324
layers = layer_regex[layers]
2325
2326
# Data generators
2327
train_generator = data_generator(train_dataset, self.config, shuffle=True,
2328
augmentation=augmentation,
2329
batch_size=self.config.BATCH_SIZE,
2330
no_augmentation_sources=no_augmentation_sources)
2331
val_generator = data_generator(val_dataset, self.config, shuffle=True,
2332
batch_size=self.config.BATCH_SIZE)
2333
2334
# Create log_dir if it does not exist
2335
if not os.path.exists(self.log_dir):
2336
os.makedirs(self.log_dir)
2337
2338
# Callbacks
2339
callbacks = [
2340
keras.callbacks.TensorBoard(log_dir=self.log_dir,
2341
histogram_freq=0, write_graph=True, write_images=False),
2342
keras.callbacks.ModelCheckpoint(self.checkpoint_path,
2343
verbose=0, save_weights_only=True),
2344
]
2345
2346
# Add custom callbacks to the list
2347
if custom_callbacks:
2348
callbacks += custom_callbacks
2349
2350
# Train
2351
log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
2352
log("Checkpoint Path: {}".format(self.checkpoint_path))
2353
self.set_trainable(layers)
2354
self.compile(learning_rate, self.config.LEARNING_MOMENTUM)
2355
2356
# Work-around for Windows: Keras fails on Windows when using
2357
# multiprocessing workers. See discussion here:
2358
# https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
2359
if os.name == 'nt':
2360
workers = 0
2361
else:
2362
workers = multiprocessing.cpu_count()
2363
2364
self.keras_model.fit_generator(
2365
train_generator,
2366
initial_epoch=self.epoch,
2367
epochs=epochs,
2368
steps_per_epoch=self.config.STEPS_PER_EPOCH,
2369
callbacks=callbacks,
2370
validation_data=val_generator,
2371
validation_steps=self.config.VALIDATION_STEPS,
2372
max_queue_size=100,
2373
workers=workers,
2374
use_multiprocessing=True,
2375
)
2376
self.epoch = max(self.epoch, epochs)
2377
2378
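    # Hedged usage sketch (`train_set` and `val_set` are assumed Dataset
    # objects): a common schedule is to train the heads first, then fine-tune
    # all layers at a lower learning rate. Note that `epochs` is cumulative.
    #   model.train(train_set, val_set, learning_rate=config.LEARNING_RATE,
    #               epochs=40, layers="heads")
    #   model.train(train_set, val_set, learning_rate=config.LEARNING_RATE / 10,
    #               epochs=120, layers="all")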
def mold_inputs(self, images):
2379
"""Takes a list of images and modifies them to the format expected
2380
as an input to the neural network.
2381
images: List of image matrices [height,width,depth]. Images can have
2382
different sizes.
2383
2384
Returns 3 Numpy matrices:
2385
molded_images: [N, h, w, 3]. Images resized and normalized.
2386
image_metas: [N, length of meta data]. Details about each image.
2387
windows: [N, (y1, x1, y2, x2)]. The portion of the image that has the
2388
original image (padding excluded).
2389
"""
2390
molded_images = []
2391
image_metas = []
2392
windows = []
2393
for image in images:
2394
# Resize image
2395
# TODO: move resizing to mold_image()
2396
molded_image, window, scale, padding, crop = utils.resize_image(
2397
image,
2398
min_dim=self.config.IMAGE_MIN_DIM,
2399
min_scale=self.config.IMAGE_MIN_SCALE,
2400
max_dim=self.config.IMAGE_MAX_DIM,
2401
mode=self.config.IMAGE_RESIZE_MODE)
2402
molded_image = mold_image(molded_image, self.config)
2403
# Build image_meta
2404
image_meta = compose_image_meta(
2405
0, image.shape, molded_image.shape, window, scale,
2406
np.zeros([self.config.NUM_CLASSES], dtype=np.int32))
2407
# Append
2408
molded_images.append(molded_image)
2409
windows.append(window)
2410
image_metas.append(image_meta)
2411
# Pack into arrays
2412
molded_images = np.stack(molded_images)
2413
image_metas = np.stack(image_metas)
2414
windows = np.stack(windows)
2415
return molded_images, image_metas, windows
2416
2417
def unmold_detections(self, detections, mrcnn_mask, original_image_shape,
2418
image_shape, window):
2419
"""Reformats the detections of one image from the format of the neural
2420
network output to a format suitable for use in the rest of the
2421
application.
2422
2423
detections: [N, (y1, x1, y2, x2, class_id, score)] in normalized coordinates
2424
mrcnn_mask: [N, height, width, num_classes]
2425
original_image_shape: [H, W, C] Original image shape before resizing
2426
image_shape: [H, W, C] Shape of the image after resizing and padding
2427
window: [y1, x1, y2, x2] Pixel coordinates of box in the image where the real
2428
image is excluding the padding.
2429
2430
Returns:
2431
boxes: [N, (y1, x1, y2, x2)] Bounding boxes in pixels
2432
class_ids: [N] Integer class IDs for each bounding box
2433
scores: [N] Float probability scores of the class_id
2434
masks: [height, width, num_instances] Instance masks
2435
"""
2436
# How many detections do we have?
2437
# Detections array is padded with zeros. Find the first class_id == 0.
2438
zero_ix = np.where(detections[:, 4] == 0)[0]
2439
N = zero_ix[0] if zero_ix.shape[0] > 0 else detections.shape[0]
2440
2441
# Extract boxes, class_ids, scores, and class-specific masks
2442
boxes = detections[:N, :4]
2443
class_ids = detections[:N, 4].astype(np.int32)
2444
scores = detections[:N, 5]
2445
masks = mrcnn_mask[np.arange(N), :, :, class_ids]
2446
2447
# Translate normalized coordinates in the resized image to pixel
2448
# coordinates in the original image before resizing
2449
window = utils.norm_boxes(window, image_shape[:2])
2450
wy1, wx1, wy2, wx2 = window
2451
shift = np.array([wy1, wx1, wy1, wx1])
2452
wh = wy2 - wy1 # window height
2453
ww = wx2 - wx1 # window width
2454
scale = np.array([wh, ww, wh, ww])
2455
# Convert boxes to normalized coordinates on the window
2456
boxes = np.divide(boxes - shift, scale)
2457
# Convert boxes to pixel coordinates on the original image
2458
boxes = utils.denorm_boxes(boxes, original_image_shape[:2])
2459
2460
# Filter out detections with zero area. Happens in early training when
2461
# network weights are still random
2462
exclude_ix = np.where(
2463
(boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
2464
if exclude_ix.shape[0] > 0:
2465
boxes = np.delete(boxes, exclude_ix, axis=0)
2466
class_ids = np.delete(class_ids, exclude_ix, axis=0)
2467
scores = np.delete(scores, exclude_ix, axis=0)
2468
masks = np.delete(masks, exclude_ix, axis=0)
2469
N = class_ids.shape[0]
2470
2471
# Resize masks to original image size and set boundary threshold.
2472
full_masks = []
2473
for i in range(N):
2474
# Convert neural network mask to full size mask
2475
full_mask = utils.unmold_mask(masks[i], boxes[i], original_image_shape)
2476
full_masks.append(full_mask)
2477
full_masks = np.stack(full_masks, axis=-1)\
2478
if full_masks else np.empty(original_image_shape[:2] + (0,))
2479
2480
return boxes, class_ids, scores, full_masks
2481
2482
def detect(self, images, verbose=0):
2483
"""Runs the detection pipeline.
2484
2485
images: List of images, potentially of different sizes.
2486
2487
Returns a list of dicts, one dict per image. The dict contains:
2488
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
2489
class_ids: [N] int class IDs
2490
scores: [N] float probability scores for the class IDs
2491
masks: [H, W, N] instance binary masks
2492
"""
2493
assert self.mode == "inference", "Create model in inference mode."
2494
assert len(
2495
images) == self.config.BATCH_SIZE, "len(images) must be equal to BATCH_SIZE"
2496
2497
if verbose:
2498
log("Processing {} images".format(len(images)))
2499
for image in images:
2500
log("image", image)
2501
2502
# Mold inputs to format expected by the neural network
2503
molded_images, image_metas, windows = self.mold_inputs(images)
2504
2505
# Validate image sizes
2506
# All images in a batch MUST be of the same size
2507
image_shape = molded_images[0].shape
2508
for g in molded_images[1:]:
2509
assert g.shape == image_shape,\
2510
"After resizing, all images must have the same size. Check IMAGE_RESIZE_MODE and image sizes."
2511
2512
# Anchors
2513
anchors = self.get_anchors(image_shape)
2514
# Duplicate across the batch dimension because Keras requires it
2515
# TODO: can this be optimized to avoid duplicating the anchors?
2516
anchors = np.broadcast_to(anchors, (self.config.BATCH_SIZE,) + anchors.shape)
2517
2518
if verbose:
2519
log("molded_images", molded_images)
2520
log("image_metas", image_metas)
2521
log("anchors", anchors)
2522
# Run object detection
2523
detections, _, _, mrcnn_mask, _, _, _ =\
2524
self.keras_model.predict([molded_images, image_metas, anchors], verbose=0)
2525
# Process detections
2526
results = []
2527
for i, image in enumerate(images):
2528
final_rois, final_class_ids, final_scores, final_masks =\
2529
self.unmold_detections(detections[i], mrcnn_mask[i],
2530
image.shape, molded_images[i].shape,
2531
windows[i])
2532
results.append({
2533
"rois": final_rois,
2534
"class_ids": final_class_ids,
2535
"scores": final_scores,
2536
"masks": final_masks,
2537
})
2538
return results
2539
2540
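    # Hedged usage sketch: run detection on one RGB image (assumes an
    # inference-mode model with config.BATCH_SIZE == 1 and `image` as a
    # [H, W, 3] uint8 array):
    #   results = model.detect([image], verbose=1)
    #   r = results[0]
    #   # r["rois"], r["class_ids"], r["scores"], r["masks"]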
def detect_molded(self, molded_images, image_metas, verbose=0):
2541
"""Runs the detection pipeline, but expect inputs that are
2542
molded already. Used mostly for debugging and inspecting
2543
the model.
2544
2545
molded_images: List of images loaded using load_image_gt()
2546
image_metas: image meta data, also returned by load_image_gt()
2547
2548
Returns a list of dicts, one dict per image. The dict contains:
2549
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
2550
class_ids: [N] int class IDs
2551
scores: [N] float probability scores for the class IDs
2552
masks: [H, W, N] instance binary masks
2553
"""
2554
assert self.mode == "inference", "Create model in inference mode."
2555
assert len(molded_images) == self.config.BATCH_SIZE,\
2556
"Number of images must be equal to BATCH_SIZE"
2557
2558
if verbose:
2559
log("Processing {} images".format(len(molded_images)))
2560
for image in molded_images:
2561
log("image", image)
2562
2563
# Validate image sizes
2564
# All images in a batch MUST be of the same size
2565
image_shape = molded_images[0].shape
2566
for g in molded_images[1:]:
2567
assert g.shape == image_shape, "Images must have the same size"
2568
2569
# Anchors
2570
anchors = self.get_anchors(image_shape)
2571
# Duplicate across the batch dimension because Keras requires it
2572
# TODO: can this be optimized to avoid duplicating the anchors?
2573
anchors = np.broadcast_to(anchors, (self.config.BATCH_SIZE,) + anchors.shape)
2574
2575
if verbose:
2576
log("molded_images", molded_images)
2577
log("image_metas", image_metas)
2578
log("anchors", anchors)
2579
# Run object detection
2580
detections, _, _, mrcnn_mask, _, _, _ =\
2581
self.keras_model.predict([molded_images, image_metas, anchors], verbose=0)
2582
# Process detections
2583
results = []
2584
for i, image in enumerate(molded_images):
2585
window = [0, 0, image.shape[0], image.shape[1]]
2586
final_rois, final_class_ids, final_scores, final_masks =\
2587
self.unmold_detections(detections[i], mrcnn_mask[i],
2588
image.shape, molded_images[i].shape,
2589
window)
2590
results.append({
2591
"rois": final_rois,
2592
"class_ids": final_class_ids,
2593
"scores": final_scores,
2594
"masks": final_masks,
2595
})
2596
return results
2597
2598
def get_anchors(self, image_shape):
2599
"""Returns anchor pyramid for the given image size."""
2600
backbone_shapes = compute_backbone_shapes(self.config, image_shape)
2601
# Cache anchors and reuse if image shape is the same
2602
if not hasattr(self, "_anchor_cache"):
2603
self._anchor_cache = {}
2604
if not tuple(image_shape) in self._anchor_cache:
2605
# Generate Anchors
2606
a = utils.generate_pyramid_anchors(
2607
self.config.RPN_ANCHOR_SCALES,
2608
self.config.RPN_ANCHOR_RATIOS,
2609
backbone_shapes,
2610
self.config.BACKBONE_STRIDES,
2611
self.config.RPN_ANCHOR_STRIDE)
2612
# Keep a copy of the latest anchors in pixel coordinates because
2613
# it's used in inspect_model notebooks.
2614
# TODO: Remove this after the notebooks are refactored to not use it
2615
self.anchors = a
2616
# Normalize coordinates
2617
self._anchor_cache[tuple(image_shape)] = utils.norm_boxes(a, image_shape[:2])
2618
return self._anchor_cache[tuple(image_shape)]
2619
2620
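    # Hedged note: get_anchors() returns anchors in normalized coordinates and
    # caches them per image shape, so repeated calls with the same shape reuse
    # the cached array, e.g.:
    #   a = model.get_anchors(config.IMAGE_SHAPE)   # computed and cached
    #   a = model.get_anchors(config.IMAGE_SHAPE)   # served from _anchor_cache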
def ancestor(self, tensor, name, checked=None):
2621
"""Finds the ancestor of a TF tensor in the computation graph.
2622
tensor: TensorFlow symbolic tensor.
2623
name: Name of ancestor tensor to find
2624
checked: For internal use. A list of tensors that were already
2625
searched to avoid loops in traversing the graph.
2626
"""
2627
checked = checked if checked is not None else []
2628
# Put a limit on how deep we go to avoid very long loops
2629
if len(checked) > 500:
2630
return None
2631
# Convert name to a regex and allow matching a number prefix
2632
# because Keras adds them automatically
2633
if isinstance(name, str):
2634
name = re.compile(name.replace("/", r"(\_\d+)*/"))
2635
2636
parents = tensor.op.inputs
2637
for p in parents:
2638
if p in checked:
2639
continue
2640
if bool(re.fullmatch(name, p.name)):
2641
return p
2642
checked.append(p)
2643
a = self.ancestor(p, name, checked)
2644
if a is not None:
2645
return a
2646
return None
2647
2648
def find_trainable_layer(self, layer):
2649
"""If a layer is encapsulated by another layer, this function
2650
digs through the encapsulation and returns the layer that holds
2651
the weights.
2652
"""
2653
if layer.__class__.__name__ == 'TimeDistributed':
2654
return self.find_trainable_layer(layer.layer)
2655
return layer
2656
2657
def get_trainable_layers(self):
2658
"""Returns a list of layers that have weights."""
2659
layers = []
2660
# Loop through all layers
2661
for l in self.keras_model.layers:
2662
# If layer is a wrapper, find inner trainable layer
2663
l = self.find_trainable_layer(l)
2664
# Include layer if it has weights
2665
if l.get_weights():
2666
layers.append(l)
2667
return layers
2668
2669
def run_graph(self, images, outputs, image_metas=None):
2670
"""Runs a sub-set of the computation graph that computes the given
2671
outputs.
2672
2673
image_metas: If provided, the images are assumed to be already
2674
molded (i.e. resized, padded, and normalized)
2675
2676
outputs: List of tuples (name, tensor) to compute. The tensors are
2677
symbolic TensorFlow tensors and the names are for easy tracking.
2678
2679
Returns an ordered dict of results. Keys are the names received in the
2680
input and values are Numpy arrays.
2681
"""
2682
model = self.keras_model
2683
2684
# Organize desired outputs into an ordered dict
2685
outputs = OrderedDict(outputs)
2686
for o in outputs.values():
2687
assert o is not None
2688
2689
# Build a Keras function to run parts of the computation graph
2690
inputs = model.inputs
2691
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
2692
inputs += [K.learning_phase()]
2693
kf = K.function(model.inputs, list(outputs.values()))
2694
2695
# Prepare inputs
2696
if image_metas is None:
2697
molded_images, image_metas, _ = self.mold_inputs(images)
2698
else:
2699
molded_images = images
2700
image_shape = molded_images[0].shape
2701
# Anchors
2702
anchors = self.get_anchors(image_shape)
2703
# Duplicate across the batch dimension because Keras requires it
2704
# TODO: can this be optimized to avoid duplicating the anchors?
2705
anchors = np.broadcast_to(anchors, (self.config.BATCH_SIZE,) + anchors.shape)
2706
model_in = [molded_images, image_metas, anchors]
2707
2708
# Run inference
2709
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
2710
model_in.append(0.)
2711
outputs_np = kf(model_in)
2712
2713
# Pack the generated Numpy arrays into a dict and log the results.
2714
outputs_np = OrderedDict([(k, v)
2715
for k, v in zip(outputs.keys(), outputs_np)])
2716
for k, v in outputs_np.items():
2717
log(k, v)
2718
return outputs_np
2719
2720
2721
############################################################
2722
# Data Formatting
2723
############################################################
2724
2725
def compose_image_meta(image_id, original_image_shape, image_shape,
2726
window, scale, active_class_ids):
2727
"""Takes attributes of an image and puts them in one 1D array.
2728
2729
image_id: An int ID of the image. Useful for debugging.
2730
original_image_shape: [H, W, C] before resizing or padding.
2731
image_shape: [H, W, C] after resizing and padding
2732
window: (y1, x1, y2, x2) in pixels. The area of the image where the real
2733
image is (excluding the padding)
2734
scale: The scaling factor applied to the original image (float32)
2735
active_class_ids: List of class_ids available in the dataset from which
2736
the image came. Useful if training on images from multiple datasets
2737
where not all classes are present in all datasets.
2738
"""
2739
meta = np.array(
2740
[image_id] + # size=1
2741
list(original_image_shape) + # size=3
2742
list(image_shape) + # size=3
2743
list(window) + # size=4 (y1, x1, y2, x2) in image coordinates
2744
[scale] + # size=1
2745
list(active_class_ids) # size=num_classes
2746
)
2747
return meta
2748


def parse_image_meta(meta):
    """Parses an array that contains image attributes to its components.
    See compose_image_meta() for more details.

    meta: [batch, meta length] where meta length depends on NUM_CLASSES

    Returns a dict of the parsed values.
    """
    image_id = meta[:, 0]
    original_image_shape = meta[:, 1:4]
    image_shape = meta[:, 4:7]
    window = meta[:, 7:11]  # (y1, x1, y2, x2) window of image in pixels
    scale = meta[:, 11]
    active_class_ids = meta[:, 12:]
    return {
        "image_id": image_id.astype(np.int32),
        "original_image_shape": original_image_shape.astype(np.int32),
        "image_shape": image_shape.astype(np.int32),
        "window": window.astype(np.int32),
        "scale": scale.astype(np.float32),
        "active_class_ids": active_class_ids.astype(np.int32),
    }
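

# Illustrative sketch (an assumption, not part of the original API):
# parse_image_meta() expects a batch dimension, so a single meta vector is
# wrapped with np.newaxis. Values round-trip through compose_image_meta().
# _example_compose_image_meta is the hypothetical helper defined above.
def _example_parse_image_meta():
    meta = _example_compose_image_meta()
    parsed = parse_image_meta(meta[np.newaxis])   # shape [1, 12 + num_classes]
    assert parsed["image_id"][0] == 3
    assert tuple(parsed["image_shape"][0]) == (1024, 1024, 3)
    assert np.isclose(parsed["scale"][0], 1.6)
    return parsed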


def parse_image_meta_graph(meta):
    """Parses a tensor that contains image attributes to its components.
    See compose_image_meta() for more details.

    meta: [batch, meta length] where meta length depends on NUM_CLASSES

    Returns a dict of the parsed tensors.
    """
    image_id = meta[:, 0]
    original_image_shape = meta[:, 1:4]
    image_shape = meta[:, 4:7]
    window = meta[:, 7:11]  # (y1, x1, y2, x2) window of image in pixels
    scale = meta[:, 11]
    active_class_ids = meta[:, 12:]
    return {
        "image_id": image_id,
        "original_image_shape": original_image_shape,
        "image_shape": image_shape,
        "window": window,
        "scale": scale,
        "active_class_ids": active_class_ids,
    }
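

# Illustrative sketch (an assumption, not part of the original API): the graph
# version does the same slicing on a symbolic tensor. Written for TF1 graph
# mode, as required at the top of this file, so a session evaluates the dict.
def _example_parse_image_meta_graph():
    meta = tf.constant(_example_compose_image_meta()[np.newaxis], dtype=tf.float32)
    parsed = parse_image_meta_graph(meta)   # dict of symbolic tensors
    with tf.Session() as sess:
        return sess.run(parsed)             # dict of Numpy arrays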


def mold_image(images, config):
    """Expects an RGB image (or array of images), subtracts the mean pixel,
    and converts it to float. Expects image colors in RGB order.
    """
    return images.astype(np.float32) - config.MEAN_PIXEL


def unmold_image(normalized_images, config):
    """Takes an image normalized with mold_image() and returns the original."""
    return (normalized_images + config.MEAN_PIXEL).astype(np.uint8)
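

# Illustrative sketch (an assumption, not part of the original API): molding
# subtracts config.MEAN_PIXEL and casts to float; unmolding adds it back and
# casts to uint8, so a round trip recovers the image up to rounding. config is
# assumed to be an mrcnn Config instance providing MEAN_PIXEL.
def _example_mold_unmold(config):
    image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
    molded = mold_image(image, config)
    restored = unmold_image(molded, config)
    # .astype(np.uint8) truncates, so allow an off-by-one difference
    assert np.max(np.abs(restored.astype(np.int32) - image.astype(np.int32))) <= 1
    return restored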


############################################################
#  Miscellaneous Graph Functions
############################################################

def trim_zeros_graph(boxes, name='trim_zeros'):
    """Often boxes are represented with matrices of shape [N, 4] and
    are padded with zeros. This removes zero boxes.

    boxes: [N, 4] matrix of boxes.

    Returns:
    boxes: [M, 4] matrix with the all-zero rows removed.
    non_zeros: [N] a 1D boolean mask identifying the rows to keep
    """
    non_zeros = tf.cast(tf.reduce_sum(tf.abs(boxes), axis=1), tf.bool)
    boxes = tf.boolean_mask(boxes, non_zeros, name=name)
    return boxes, non_zeros
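

# Illustrative sketch (an assumption, not part of the original API): rows that
# are all zeros are treated as padding and removed. Written for TF1 graph mode,
# so a session evaluates the result.
def _example_trim_zeros_graph():
    boxes = tf.constant([[10, 10, 20, 20],
                         [0, 0, 0, 0],
                         [5, 5, 15, 15]], dtype=tf.float32)
    trimmed, non_zeros = trim_zeros_graph(boxes)
    with tf.Session() as sess:
        trimmed_np, non_zeros_np = sess.run([trimmed, non_zeros])
    assert trimmed_np.shape == (2, 4)
    assert list(non_zeros_np) == [True, False, True]
    return trimmed_np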


def batch_pack_graph(x, counts, num_rows):
    """Picks a different number of values from each row
    in x depending on the values in counts.
    """
    outputs = []
    for i in range(num_rows):
        outputs.append(x[i, :counts[i]])
    return tf.concat(outputs, axis=0)
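

# Illustrative sketch (an assumption, not part of the original API): from row 0
# take the first 2 values, from row 1 the first 1, and concatenate the results.
def _example_batch_pack_graph():
    x = tf.constant([[1, 2, 3],
                     [4, 5, 6]])
    packed = batch_pack_graph(x, counts=[2, 1], num_rows=2)
    with tf.Session() as sess:
        result = sess.run(packed)   # -> [1, 2, 4]
    assert list(result) == [1, 2, 4]
    return result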


def norm_boxes_graph(boxes, shape):
    """Converts boxes from pixel coordinates to normalized coordinates.
    boxes: [..., (y1, x1, y2, x2)] in pixel coordinates
    shape: [..., (height, width)] in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        [..., (y1, x1, y2, x2)] in normalized coordinates
    """
    h, w = tf.split(tf.cast(shape, tf.float32), 2)
    scale = tf.concat([h, w, h, w], axis=-1) - tf.constant(1.0)
    shift = tf.constant([0., 0., 1., 1.])
    return tf.divide(boxes - shift, scale)
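

# Illustrative sketch (an assumption, not part of the original API): in a
# 1024x1024 image, the full-image box (0, 0, 1024, 1024) in pixel coordinates
# (where y2, x2 lie outside the box) maps to (0, 0, 1, 1) in normalized
# coordinates, because the shift and the (size - 1) scale account for the
# half-open pixel convention.
def _example_norm_boxes_graph():
    boxes = tf.constant([[0., 0., 1024., 1024.]])
    normalized = norm_boxes_graph(boxes, shape=tf.constant([1024., 1024.]))
    with tf.Session() as sess:
        result = sess.run(normalized)   # -> [[0., 0., 1., 1.]]
    assert np.allclose(result, [[0., 0., 1., 1.]])
    return result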


def denorm_boxes_graph(boxes, shape):
    """Converts boxes from normalized coordinates to pixel coordinates.
    boxes: [..., (y1, x1, y2, x2)] in normalized coordinates
    shape: [..., (height, width)] in pixels

    Note: In pixel coordinates (y2, x2) is outside the box. But in normalized
    coordinates it's inside the box.

    Returns:
        [..., (y1, x1, y2, x2)] in pixel coordinates
    """
    h, w = tf.split(tf.cast(shape, tf.float32), 2)
    scale = tf.concat([h, w, h, w], axis=-1) - tf.constant(1.0)
    shift = tf.constant([0., 0., 1., 1.])
    return tf.cast(tf.round(tf.multiply(boxes, scale) + shift), tf.int32)
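

# Illustrative sketch (an assumption, not part of the original API):
# denorm_boxes_graph() inverts norm_boxes_graph() up to rounding, so a pixel
# box survives a round trip through normalized coordinates.
def _example_denorm_boxes_graph():
    shape = tf.constant([1024., 800.])              # (height, width)
    boxes = tf.constant([[100., 200., 300., 400.]])
    normalized = norm_boxes_graph(boxes, shape)
    recovered = denorm_boxes_graph(normalized, shape)
    with tf.Session() as sess:
        result = sess.run(recovered)                # -> [[100, 200, 300, 400]]
    assert np.array_equal(result, [[100, 200, 300, 400]])
    return result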