GitHub Repository: matterport/Mask_RCNN
Path: blob/master/samples/nucleus/nucleus.py
"""
Mask R-CNN
Train on the nuclei segmentation dataset from the
Kaggle 2018 Data Science Bowl
https://www.kaggle.com/c/data-science-bowl-2018/

Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla

------------------------------------------------------------

Usage: import the module (see Jupyter notebooks for examples), or run from
       the command line as such:

    # Train a new model starting from ImageNet weights
    python3 nucleus.py train --dataset=/path/to/dataset --subset=train --weights=imagenet

    # Train a new model starting from specific weights file
    python3 nucleus.py train --dataset=/path/to/dataset --subset=train --weights=/path/to/weights.h5

    # Resume training a model that you had trained earlier
    python3 nucleus.py train --dataset=/path/to/dataset --subset=train --weights=last

    # Generate submission file
    python3 nucleus.py detect --dataset=/path/to/dataset --subset=<subset> --weights=<last or /path/to/weights.h5>
"""

# Set matplotlib backend
# This has to be done before other imports that might
# set it, but only if we're running in script mode
# rather than being imported.
if __name__ == '__main__':
    import matplotlib
    # Agg backend runs without a display
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

import os
import sys
import json
import datetime
import numpy as np
import skimage.io
from imgaug import augmenters as iaa

# Root directory of the project
ROOT_DIR = os.path.abspath("../../")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
from mrcnn import model as modellib
from mrcnn import visualize

# Path to trained weights file
COCO_WEIGHTS_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Directory to save logs and model checkpoints, if not provided
# through the command line argument --logs
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")

# Results directory
# Save submission files here
RESULTS_DIR = os.path.join(ROOT_DIR, "results/nucleus/")

# The dataset doesn't have a standard train/val split, so I picked
# a variety of images to serve as a validation set.
VAL_IMAGE_IDS = [
    "0c2550a23b8a0f29a7575de8c61690d3c31bc897dd5ba66caec201d201a278c2",
    "92f31f591929a30e4309ab75185c96ff4314ce0a7ead2ed2c2171897ad1da0c7",
    "1e488c42eb1a54a3e8412b1f12cde530f950f238d71078f2ede6a85a02168e1f",
    "c901794d1a421d52e5734500c0a2a8ca84651fb93b19cec2f411855e70cae339",
    "8e507d58f4c27cd2a82bee79fe27b069befd62a46fdaed20970a95a2ba819c7b",
    "60cb718759bff13f81c4055a7679e81326f78b6a193a2d856546097c949b20ff",
    "da5f98f2b8a64eee735a398de48ed42cd31bf17a6063db46a9e0783ac13cd844",
    "9ebcfaf2322932d464f15b5662cae4d669b2d785b8299556d73fffcae8365d32",
    "1b44d22643830cd4f23c9deadb0bd499fb392fb2cd9526d81547d93077d983df",
    "97126a9791f0c1176e4563ad679a301dac27c59011f579e808bbd6e9f4cd1034",
    "e81c758e1ca177b0942ecad62cf8d321ffc315376135bcbed3df932a6e5b40c0",
    "f29fd9c52e04403cd2c7d43b6fe2479292e53b2f61969d25256d2d2aca7c6a81",
    "0ea221716cf13710214dcd331a61cea48308c3940df1d28cfc7fd817c83714e1",
    "3ab9cab6212fabd723a2c5a1949c2ded19980398b56e6080978e796f45cbbc90",
    "ebc18868864ad075548cc1784f4f9a237bb98335f9645ee727dac8332a3e3716",
    "bb61fc17daf8bdd4e16fdcf50137a8d7762bec486ede9249d92e511fcb693676",
    "e1bcb583985325d0ef5f3ef52957d0371c96d4af767b13e48102bca9d5351a9b",
    "947c0d94c8213ac7aaa41c4efc95d854246550298259cf1bb489654d0e969050",
    "cbca32daaae36a872a11da4eaff65d1068ff3f154eedc9d3fc0c214a4e5d32bd",
    "f4c4db3df4ff0de90f44b027fc2e28c16bf7e5c75ea75b0a9762bbb7ac86e7a3",
    "4193474b2f1c72f735b13633b219d9cabdd43c21d9c2bb4dfc4809f104ba4c06",
    "f73e37957c74f554be132986f38b6f1d75339f636dfe2b681a0cf3f88d2733af",
    "a4c44fc5f5bf213e2be6091ccaed49d8bf039d78f6fbd9c4d7b7428cfcb2eda4",
    "cab4875269f44a701c5e58190a1d2f6fcb577ea79d842522dcab20ccb39b7ad2",
    "8ecdb93582b2d5270457b36651b62776256ade3aaa2d7432ae65c14f07432d49",
]


############################################################
# Configurations
############################################################

class NucleusConfig(Config):
    """Configuration for training on the nucleus segmentation dataset."""
    # Give the configuration a recognizable name
    NAME = "nucleus"

    # Adjust depending on your GPU memory
    IMAGES_PER_GPU = 6

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # Background + nucleus

    # Number of training and validation steps per epoch
    STEPS_PER_EPOCH = (657 - len(VAL_IMAGE_IDS)) // IMAGES_PER_GPU
    VALIDATION_STEPS = max(1, len(VAL_IMAGE_IDS) // IMAGES_PER_GPU)
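
    # For reference: with the 657 stage1_train images assumed above and the
    # 25 validation IDs, this works out to (657 - 25) // 6 = 105 training
    # steps and 25 // 6 = 4 validation steps per epoch at the default
    # single-GPU batch size of 6.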

    # Don't exclude based on confidence. Since we have two classes
    # then 0.5 is the minimum anyway as it picks between nucleus and BG
    DETECTION_MIN_CONFIDENCE = 0

    # Backbone network architecture
    # Supported values are: resnet50, resnet101
    BACKBONE = "resnet50"

    # Input image resizing
    # Random crops of size 512x512
    IMAGE_RESIZE_MODE = "crop"
    IMAGE_MIN_DIM = 512
    IMAGE_MAX_DIM = 512
    IMAGE_MIN_SCALE = 2.0

    # Length of square anchor side in pixels
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)

    # ROIs kept after non-maximum suppression (training and inference)
    POST_NMS_ROIS_TRAINING = 1000
    POST_NMS_ROIS_INFERENCE = 2000

    # Non-max suppression threshold to filter RPN proposals.
    # You can increase this during training to generate more proposals.
    RPN_NMS_THRESHOLD = 0.9

    # How many anchors per image to use for RPN training
    RPN_TRAIN_ANCHORS_PER_IMAGE = 64

    # Image mean (RGB)
    MEAN_PIXEL = np.array([43.53, 39.56, 48.22])

    # If enabled, resizes instance masks to a smaller size to reduce
    # memory load. Recommended when using high-resolution images.
    USE_MINI_MASK = True
    MINI_MASK_SHAPE = (56, 56)  # (height, width) of the mini-mask

    # Number of ROIs per image to feed to classifier/mask heads
    # The Mask RCNN paper uses 512 but often the RPN doesn't generate
    # enough positive proposals to fill this and keep a positive:negative
    # ratio of 1:3. You can increase the number of proposals by adjusting
    # the RPN NMS threshold.
    TRAIN_ROIS_PER_IMAGE = 128

    # Maximum number of ground truth instances to use in one image
    MAX_GT_INSTANCES = 200

    # Max number of final detections per image
    DETECTION_MAX_INSTANCES = 400


class NucleusInferenceConfig(NucleusConfig):
    # Set batch size to 1 to run one image at a time
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    # Don't resize images for inference. "pad64" pads width and height
    # with zeros up to multiples of 64, which keeps arbitrary image sizes
    # compatible with the backbone's downsampling.
    IMAGE_RESIZE_MODE = "pad64"
    # Non-max suppression threshold to filter RPN proposals.
    # A lower threshold than training filters overlapping proposals
    # more aggressively at inference.
    RPN_NMS_THRESHOLD = 0.7


############################################################
# Dataset
############################################################

class NucleusDataset(utils.Dataset):

    def load_nucleus(self, dataset_dir, subset):
        """Load a subset of the nuclei dataset.

        dataset_dir: Root directory of the dataset
        subset: Subset to load. Either the name of the sub-directory,
                such as stage1_train, stage1_test, ...etc. or, one of:
                * train: stage1_train excluding validation images
                * val: validation images from VAL_IMAGE_IDS
        """
        # Add classes. We have one class.
        # Naming the dataset nucleus, and the class nucleus
        self.add_class("nucleus", 1, "nucleus")

        # Which subset?
        # "val": use hard-coded list above
        # "train": use data from stage1_train minus the hard-coded list above
        # else: use the data from the specified sub-directory
        assert subset in ["train", "val", "stage1_train", "stage1_test", "stage2_test"]
        subset_dir = "stage1_train" if subset in ["train", "val"] else subset
        dataset_dir = os.path.join(dataset_dir, subset_dir)
        if subset == "val":
            image_ids = VAL_IMAGE_IDS
        else:
            # Get image ids from directory names
            image_ids = next(os.walk(dataset_dir))[1]
            if subset == "train":
                image_ids = list(set(image_ids) - set(VAL_IMAGE_IDS))

        # Add images
        for image_id in image_ids:
            self.add_image(
                "nucleus",
                image_id=image_id,
                path=os.path.join(dataset_dir, image_id, "images/{}.png".format(image_id)))
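
        # The loop above assumes the Data Science Bowl directory layout,
        # which load_mask() below relies on as well:
        #   <dataset_dir>/<image_id>/images/<image_id>.png  (the image)
        #   <dataset_dir>/<image_id>/masks/*.png            (one binary mask per instance)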

    def load_mask(self, image_id):
        """Generate instance masks for an image.
        Returns:
            masks: A bool array of shape [height, width, instance count] with
                one mask per instance.
            class_ids: a 1D array of class IDs of the instance masks.
        """
        info = self.image_info[image_id]
        # Get mask directory from image path
        mask_dir = os.path.join(os.path.dirname(os.path.dirname(info['path'])), "masks")

        # Read mask files from .png images
        mask = []
        for f in next(os.walk(mask_dir))[2]:
            if f.endswith(".png"):
                # Plain bool here; np.bool was removed in newer NumPy
                m = skimage.io.imread(os.path.join(mask_dir, f)).astype(bool)
                mask.append(m)
        mask = np.stack(mask, axis=-1)
        # Return mask, and array of class IDs of each instance. Since we have
        # one class ID, we return an array of ones
        return mask, np.ones([mask.shape[-1]], dtype=np.int32)

    def image_reference(self, image_id):
        """Return the path of the image."""
        info = self.image_info[image_id]
        if info["source"] == "nucleus":
            return info["id"]
        else:
            return super(self.__class__, self).image_reference(image_id)


############################################################
# Training
############################################################

def train(model, dataset_dir, subset):
    """Train the model."""
    # Training dataset.
    dataset_train = NucleusDataset()
    dataset_train.load_nucleus(dataset_dir, subset)
    dataset_train.prepare()

    # Validation dataset
    dataset_val = NucleusDataset()
    dataset_val.load_nucleus(dataset_dir, "val")
    dataset_val.prepare()

    # Image augmentation
    # http://imgaug.readthedocs.io/en/latest/source/augmenters.html
    augmentation = iaa.SomeOf((0, 2), [
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.OneOf([iaa.Affine(rotate=90),
                   iaa.Affine(rotate=180),
                   iaa.Affine(rotate=270)]),
        iaa.Multiply((0.8, 1.5)),
        iaa.GaussianBlur(sigma=(0.0, 5.0))
    ])
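
    # SomeOf((0, 2)) applies between zero and two of the listed augmenters
    # to each image, so flips, 90-degree rotations, brightness changes, and
    # blur are mixed in at random rather than stacked all at once.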

    # *** This training schedule is an example. Update to your needs ***

    # If starting from imagenet, train heads only for a bit
    # since they have random weights
    print("Train network heads")
    model.train(dataset_train, dataset_val,
                learning_rate=config.LEARNING_RATE,
                epochs=20,
                augmentation=augmentation,
                layers='heads')

    print("Train all layers")
    model.train(dataset_train, dataset_val,
                learning_rate=config.LEARNING_RATE,
                epochs=40,
                augmentation=augmentation,
                layers='all')
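
# train() reads the module-level `config` global, which the command-line
# block at the bottom of this file sets. A minimal sketch of the equivalent
# interactive use (the dataset path is a placeholder):
#
#   import nucleus
#   from mrcnn import model as modellib
#   nucleus.config = nucleus.NucleusConfig()
#   model = modellib.MaskRCNN(mode="training", config=nucleus.config,
#                             model_dir=nucleus.DEFAULT_LOGS_DIR)
#   model.load_weights(model.get_imagenet_weights(), by_name=True)
#   nucleus.train(model, "/path/to/dataset", "train")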


############################################################
# RLE Encoding
############################################################

def rle_encode(mask):
    """Encodes a mask in Run Length Encoding (RLE).
    Returns a string of space-separated values.
    """
    assert mask.ndim == 2, "Mask must be of shape [Height, Width]"
    # Flatten it column-wise
    m = mask.T.flatten()
    # Compute gradient. Equals 1 or -1 at transition points
    g = np.diff(np.concatenate([[0], m, [0]]), n=1)
    # 1-based indices of transition points (where gradient != 0)
    rle = np.where(g != 0)[0].reshape([-1, 2]) + 1
    # Convert second index in each pair to length
    rle[:, 1] = rle[:, 1] - rle[:, 0]
    return " ".join(map(str, rle.flatten()))
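
# A worked example of the encoding above: the column-wise flattened mask
#   [0, 1, 1, 1, 0, 0, 1, 0]
# has transitions at 1-based positions 2, 5, 7, and 8, which pair into
# (start, length) runs and encode as "2 3 7 1".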


def rle_decode(rle, shape):
    """Decodes an RLE encoded list of space separated
    numbers and returns a binary mask."""
    rle = list(map(int, rle.split()))
    rle = np.array(rle, dtype=np.int32).reshape([-1, 2])
    rle[:, 1] += rle[:, 0]
    rle -= 1
    # Plain bool here; np.bool was removed in newer NumPy
    mask = np.zeros([shape[0] * shape[1]], bool)
    for s, e in rle:
        assert 0 <= s < mask.shape[0]
        assert 1 <= e <= mask.shape[0], "shape: {} s {} e {}".format(shape, s, e)
        mask[s:e] = 1
    # Reshape and transpose
    mask = mask.reshape([shape[1], shape[0]]).T
    return mask
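
# Round-trip property, given the same conventions on both sides:
# rle_decode(rle_encode(m), m.shape) reproduces m, since both flatten
# column-wise and use 1-based (start, length) pairs.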


def mask_to_rle(image_id, mask, scores):
    "Encodes instance masks to submission format."
    assert mask.ndim == 3, "Mask must be [H, W, count]"
    # If mask is empty, return line with image ID only
    if mask.shape[-1] == 0:
        return "{},".format(image_id)
    # Remove mask overlaps
    # Multiply each instance mask by its score order
    # then take the maximum across the last dimension
    order = np.argsort(scores)[::-1] + 1  # 1-based descending
    mask = np.max(mask * np.reshape(order, [1, 1, -1]), -1)
    # Loop over instance masks
    lines = []
    for o in order:
        m = np.where(mask == o, 1, 0)
        # Skip if empty
        if m.sum() == 0.0:
            continue
        rle = rle_encode(m)
        lines.append("{}, {}".format(image_id, rle))
    return "\n".join(lines)
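
# The overlap removal above works because each instance is multiplied by a
# distinct positive label before the pixelwise max, so every pixel ends up
# belonging to exactly one instance; masks left empty after that are skipped.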


############################################################
# Detection
############################################################

def detect(model, dataset_dir, subset):
    """Run detection on images in the given directory."""
    print("Running on {}".format(dataset_dir))

    # Create directory
    if not os.path.exists(RESULTS_DIR):
        os.makedirs(RESULTS_DIR)
    submit_dir = "submit_{:%Y%m%dT%H%M%S}".format(datetime.datetime.now())
    submit_dir = os.path.join(RESULTS_DIR, submit_dir)
    os.makedirs(submit_dir)

    # Read dataset
    dataset = NucleusDataset()
    dataset.load_nucleus(dataset_dir, subset)
    dataset.prepare()
    # Loop over images
    submission = []
    for image_id in dataset.image_ids:
        # Load image and run detection
        image = dataset.load_image(image_id)
        # Detect objects
        r = model.detect([image], verbose=0)[0]
        # Encode image to RLE. Returns a string of multiple lines
        source_id = dataset.image_info[image_id]["id"]
        rle = mask_to_rle(source_id, r["masks"], r["scores"])
        submission.append(rle)
        # Save image with masks
        visualize.display_instances(
            image, r['rois'], r['masks'], r['class_ids'],
            dataset.class_names, r['scores'],
            show_bbox=False, show_mask=False,
            title="Predictions")
        plt.savefig("{}/{}.png".format(submit_dir, dataset.image_info[image_id]["id"]))

    # Save to csv file
    submission = "ImageId,EncodedPixels\n" + "\n".join(submission)
    file_path = os.path.join(submit_dir, "submit.csv")
    with open(file_path, "w") as f:
        f.write(submission)
    print("Saved to ", submit_dir)
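
# The resulting submit.csv matches the competition's expected layout, e.g.
# with hypothetical values:
#   ImageId,EncodedPixels
#   9ebcfaf23229..., 1 3 10 5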


############################################################
# Command Line
############################################################

if __name__ == '__main__':
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Mask R-CNN for nuclei counting and segmentation')
    parser.add_argument("command",
                        metavar="<command>",
                        help="'train' or 'detect'")
    parser.add_argument('--dataset', required=False,
                        metavar="/path/to/dataset/",
                        help='Root directory of the dataset')
    parser.add_argument('--weights', required=True,
                        metavar="/path/to/weights.h5",
                        help="Path to weights .h5 file or 'coco'")
    parser.add_argument('--logs', required=False,
                        default=DEFAULT_LOGS_DIR,
                        metavar="/path/to/logs/",
                        help='Logs and checkpoints directory (default=logs/)')
    parser.add_argument('--subset', required=False,
                        metavar="Dataset sub-directory",
                        help="Subset of dataset to run prediction on")
    args = parser.parse_args()

    # Validate arguments
    if args.command == "train":
        assert args.dataset, "Argument --dataset is required for training"
    elif args.command == "detect":
        assert args.subset, "Provide --subset to run prediction on"

    print("Weights: ", args.weights)
    print("Dataset: ", args.dataset)
    if args.subset:
        print("Subset: ", args.subset)
    print("Logs: ", args.logs)

    # Configurations
    if args.command == "train":
        config = NucleusConfig()
    else:
        config = NucleusInferenceConfig()
    config.display()

    # Create model
    if args.command == "train":
        model = modellib.MaskRCNN(mode="training", config=config,
                                  model_dir=args.logs)
    else:
        model = modellib.MaskRCNN(mode="inference", config=config,
                                  model_dir=args.logs)

    # Select weights file to load
    if args.weights.lower() == "coco":
        weights_path = COCO_WEIGHTS_PATH
        # Download weights file
        if not os.path.exists(weights_path):
            utils.download_trained_weights(weights_path)
    elif args.weights.lower() == "last":
        # Find last trained weights
        weights_path = model.find_last()
    elif args.weights.lower() == "imagenet":
        # Start from ImageNet trained weights
        weights_path = model.get_imagenet_weights()
    else:
        weights_path = args.weights

    # Load weights
    print("Loading weights ", weights_path)
    if args.weights.lower() == "coco":
        # Exclude the last layers because they require a matching
        # number of classes
        model.load_weights(weights_path, by_name=True, exclude=[
            "mrcnn_class_logits", "mrcnn_bbox_fc",
            "mrcnn_bbox", "mrcnn_mask"])
    else:
        model.load_weights(weights_path, by_name=True)

    # Train or evaluate
    if args.command == "train":
        train(model, args.dataset, args.subset)
    elif args.command == "detect":
        detect(model, args.dataset, args.subset)
    else:
        print("'{}' is not recognized. "
              "Use 'train' or 'detect'".format(args.command))