Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
matterport
GitHub Repository: matterport/Mask_RCNN
Path: blob/master/samples/coco/coco.py
240 views
1
"""
2
Mask R-CNN
3
Configurations and data loading code for MS COCO.
4
5
Copyright (c) 2017 Matterport, Inc.
6
Licensed under the MIT License (see LICENSE for details)
7
Written by Waleed Abdulla
8
9
------------------------------------------------------------
10
11
Usage: import the module (see Jupyter notebooks for examples), or run from
12
the command line as such:
13
14
# Train a new model starting from pre-trained COCO weights
15
python3 coco.py train --dataset=/path/to/coco/ --model=coco
16
17
# Train a new model starting from ImageNet weights. Also auto download COCO dataset
18
python3 coco.py train --dataset=/path/to/coco/ --model=imagenet --download=True
19
20
# Continue training a model that you had trained earlier
21
python3 coco.py train --dataset=/path/to/coco/ --model=/path/to/weights.h5
22
23
# Continue training the last model you trained
24
python3 coco.py train --dataset=/path/to/coco/ --model=last
25
26
# Run COCO evaluatoin on the last model you trained
27
python3 coco.py evaluate --dataset=/path/to/coco/ --model=last
28
"""
29
30
import os
31
import sys
32
import time
33
import numpy as np
34
import imgaug # https://github.com/aleju/imgaug (pip3 install imgaug)
35
36
# Download and install the Python COCO tools from https://github.com/waleedka/coco
37
# That's a fork from the original https://github.com/pdollar/coco with a bug
38
# fix for Python 3.
39
# I submitted a pull request https://github.com/cocodataset/cocoapi/pull/50
40
# If the PR is merged then use the original repo.
41
# Note: Edit PythonAPI/Makefile and replace "python" with "python3".
42
from pycocotools.coco import COCO
43
from pycocotools.cocoeval import COCOeval
44
from pycocotools import mask as maskUtils
45
46
import zipfile
47
import urllib.request
48
import shutil
49
50
# Root directory of the project
51
ROOT_DIR = os.path.abspath("../../")
52
53
# Import Mask RCNN
54
sys.path.append(ROOT_DIR) # To find local version of the library
55
from mrcnn.config import Config
56
from mrcnn import model as modellib, utils
57
58
# COCO release used when no --year is given. Kept as a string: it is
# spliced directly into annotation file names and image folder names.
DEFAULT_DATASET_YEAR = "2014"

# Pre-trained COCO weights file, expected at the project root.
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Where logs and model checkpoints are written unless the --logs
# command-line argument overrides it.
DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs")
65
66
############################################################
67
# Configurations
68
############################################################
69
70
71
class CocoConfig(Config):
    """Mask R-CNN configuration for training on MS COCO.

    Inherits every setting from the base ``Config`` class and overrides only
    the values that are specific to the COCO dataset.
    """
    # Recognizable name for this configuration.
    NAME = "coco"

    # Class count includes the background: 1 background + 80 COCO categories.
    NUM_CLASSES = 1 + 80

    # Two images fit on a GPU with 12GB of memory; lower this value when
    # training on a smaller GPU.
    IMAGES_PER_GPU = 2

    # The base config trains on a single GPU; uncomment to use 8 GPUs.
    # GPU_COUNT = 8
88
89
90
############################################################
91
# Dataset
92
############################################################
93
94
class CocoDataset(utils.Dataset):
    """MS COCO images, annotations and instance masks for Mask R-CNN.

    Images are registered under the source name "coco". Each image's COCO
    annotations are stored with it in ``image_info`` (``annotations`` key),
    so ``load_mask`` can build masks without querying the COCO API again.
    """

    def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
                  class_map=None, return_coco=False, auto_download=False):
        """Load a subset of the COCO dataset.

        dataset_dir: The root directory of the COCO dataset.
        subset: What to load (train, val, minival, valminusminival)
        year: What dataset year to load (2014, 2017) as a string, not an integer
        class_ids: If provided, only loads images that have the given classes.
        class_map: TODO: Not implemented yet. Supports mapping classes from
            different datasets to the same class ID.
        return_coco: If True, returns the COCO object.
        auto_download: Automatically download and unzip MS-COCO images and annotations

        Returns:
            The ``COCO`` API object when ``return_coco`` is True, else None.
        """
        if auto_download is True:
            self.auto_download(dataset_dir, subset, year)

        coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year))
        # minival / valminusminival only have their own annotation files;
        # their images are read from the val<year> folder.
        if subset == "minival" or subset == "valminusminival":
            subset = "val"
        image_dir = "{}/{}{}".format(dataset_dir, subset, year)

        # Load all classes or a subset?
        if not class_ids:
            # All classes
            class_ids = sorted(coco.getCatIds())

        # All images or a subset?
        if class_ids:
            # A set removes images that contain several of the requested classes.
            unique_image_ids = set()
            for cat_id in class_ids:
                unique_image_ids.update(coco.getImgIds(catIds=[cat_id]))
            image_ids = list(unique_image_ids)
        else:
            # All images
            image_ids = list(coco.imgs.keys())

        # Add classes
        for cat_id in class_ids:
            self.add_class("coco", cat_id, coco.loadCats(cat_id)[0]["name"])

        # Add images
        for img_id in image_ids:
            self.add_image(
                "coco", image_id=img_id,
                path=os.path.join(image_dir, coco.imgs[img_id]['file_name']),
                width=coco.imgs[img_id]["width"],
                height=coco.imgs[img_id]["height"],
                annotations=coco.loadAnns(coco.getAnnIds(
                    imgIds=[img_id], catIds=class_ids, iscrowd=None)))
        if return_coco:
            return coco

    @staticmethod
    def _download_file(url, path):
        """Download ``url`` to ``path`` through a temporary ``.part`` file.

        The file only appears at ``path`` once the transfer has completed, so
        an interrupted download is never mistaken for a finished one on the
        next run.
        """
        tmp_path = path + ".part"
        with urllib.request.urlopen(url) as resp, open(tmp_path, 'wb') as out:
            shutil.copyfileobj(resp, out)
        os.replace(tmp_path, path)

    def auto_download(self, dataDir, dataType, dataYear):
        """Download the COCO dataset/annotations if requested.

        dataDir: The root directory of the COCO dataset.
        dataType: What to load (train, val, minival, valminusminival)
        dataYear: What dataset year to load (2014, 2017) as a string, not an integer
        Note:
            For 2014, use "train", "val", "minival", or "valminusminival"
            For 2017, only "train" and "val" annotations are available
        """

        # Setup paths and file names. minival/valminusminival use val images.
        if dataType == "minival" or dataType == "valminusminival":
            imgDir = "{}/{}{}".format(dataDir, "val", dataYear)
            imgZipFile = "{}/{}{}.zip".format(dataDir, "val", dataYear)
            imgURL = "http://images.cocodataset.org/zips/{}{}.zip".format("val", dataYear)
        else:
            imgDir = "{}/{}{}".format(dataDir, dataType, dataYear)
            imgZipFile = "{}/{}{}.zip".format(dataDir, dataType, dataYear)
            imgURL = "http://images.cocodataset.org/zips/{}{}.zip".format(dataType, dataYear)

        # Create main folder if it doesn't exist yet
        os.makedirs(dataDir, exist_ok=True)

        # Download images if not available locally
        if not os.path.exists(imgDir):
            print("Downloading images to " + imgZipFile + " ...")
            self._download_file(imgURL, imgZipFile)
            print("... done downloading.")
            print("Unzipping " + imgZipFile)
            with zipfile.ZipFile(imgZipFile, "r") as zip_ref:
                zip_ref.extractall(dataDir)
            # Create the image folder only after a successful download and
            # unzip, so an interrupted run is retried instead of skipped.
            os.makedirs(imgDir, exist_ok=True)
            print("... done unzipping")
        print("Will use images in " + imgDir)

        # Setup annotations data paths
        annDir = "{}/annotations".format(dataDir)
        if dataType == "minival":
            annZipFile = "{}/instances_minival2014.json.zip".format(dataDir)
            annFile = "{}/instances_minival2014.json".format(annDir)
            annURL = "https://dl.dropboxusercontent.com/s/o43o90bna78omob/instances_minival2014.json.zip?dl=0"
            unZipDir = annDir
        elif dataType == "valminusminival":
            annZipFile = "{}/instances_valminusminival2014.json.zip".format(dataDir)
            annFile = "{}/instances_valminusminival2014.json".format(annDir)
            annURL = "https://dl.dropboxusercontent.com/s/s3tw5zcg7395368/instances_valminusminival2014.json.zip?dl=0"
            unZipDir = annDir
        else:
            annZipFile = "{}/annotations_trainval{}.zip".format(dataDir, dataYear)
            annFile = "{}/instances_{}{}.json".format(annDir, dataType, dataYear)
            annURL = "http://images.cocodataset.org/annotations/annotations_trainval{}.zip".format(dataYear)
            # The trainval archive contains its own annotations/ folder.
            unZipDir = dataDir

        # Download annotations if not available locally
        os.makedirs(annDir, exist_ok=True)
        if not os.path.exists(annFile):
            if not os.path.exists(annZipFile):
                print("Downloading zipped annotations to " + annZipFile + " ...")
                self._download_file(annURL, annZipFile)
                print("... done downloading.")
            print("Unzipping " + annZipFile)
            with zipfile.ZipFile(annZipFile, "r") as zip_ref:
                zip_ref.extractall(unZipDir)
            print("... done unzipping")
        print("Will use annotations in " + annFile)

    def load_mask(self, image_id):
        """Load instance masks for the given image.

        Different datasets use different ways to store masks. This
        function converts the different mask format to one format
        in the form of a bitmap [height, width, instances].

        Returns:
        masks: A bool array of shape [height, width, instance count] with
            one mask per instance.
        class_ids: a 1D array of class IDs of the instance masks. Crowd
            annotations get a negative class ID.
        """
        # If not a COCO image, delegate to parent class.
        image_info = self.image_info[image_id]
        if image_info["source"] != "coco":
            return super(CocoDataset, self).load_mask(image_id)

        height, width = image_info["height"], image_info["width"]
        instance_masks = []
        class_ids = []
        # Build mask of shape [height, width, instance_count] and list
        # of class IDs that correspond to each channel of the mask.
        for annotation in image_info["annotations"]:
            class_id = self.map_source_class_id(
                "coco.{}".format(annotation['category_id']))
            # Skip annotations whose category maps to a falsy internal ID.
            if not class_id:
                continue
            m = self.annToMask(annotation, height, width)
            # Some objects are so small that they're less than 1 pixel area
            # and end up rounded out. Skip those objects.
            if m.max() < 1:
                continue
            # Is it a crowd? If so, use a negative class ID.
            if annotation['iscrowd']:
                class_id *= -1
                # For crowd masks, annToMask() sometimes returns a mask
                # smaller than the given dimensions. If so, replace it with
                # a mask covering the whole image.
                if m.shape[0] != height or m.shape[1] != width:
                    m = np.ones([height, width], dtype=bool)
            instance_masks.append(m)
            class_ids.append(class_id)

        # Pack instance masks into an array
        if class_ids:
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            mask = np.stack(instance_masks, axis=2).astype(bool)
            class_ids = np.array(class_ids, dtype=np.int32)
            return mask, class_ids
        # Call super class to return an empty mask
        return super(CocoDataset, self).load_mask(image_id)

    def image_reference(self, image_id):
        """Return a link to the image in the COCO Website.

        Non-COCO images defer to (and return) the parent implementation.
        """
        info = self.image_info[image_id]
        if info["source"] == "coco":
            return "http://cocodataset.org/#explore?id={}".format(info["id"])
        return super(CocoDataset, self).image_reference(image_id)

    # The following two functions are from pycocotools with a few changes.

    def annToRLE(self, ann, height, width):
        """
        Convert an annotation (polygons, uncompressed RLE, or RLE) to RLE.
        :return: RLE-encoded mask, suitable for maskUtils.decode()
        """
        segm = ann['segmentation']
        if isinstance(segm, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(segm, height, width)
            rle = maskUtils.merge(rles)
        elif isinstance(segm['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(segm, height, width)
        else:
            # rle
            rle = ann['segmentation']
        return rle

    def annToMask(self, ann, height, width):
        """
        Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask.
        :return: binary mask (numpy 2D array)
        """
        rle = self.annToRLE(ann, height, width)
        m = maskUtils.decode(rle)
        return m
309
310
311
############################################################
312
# COCO Evaluation
313
############################################################
314
315
def build_coco_results(dataset, image_ids, rois, class_ids, scores, masks):
    """Arrange detection results to match the COCO result format
    (http://cocodataset.org/#format).

    Every detection is emitted once for each ID in ``image_ids`` (callers in
    this file pass a single-element list). Boxes in ``rois`` are read as
    (y1, x1, y2, x2) and written as COCO's [x, y, width, height], rounded to
    one decimal place. Mask ``k`` is ``masks[:, :, k]``, RLE-encoded.
    Returns an empty list when ``rois`` is None.
    """
    if rois is None:
        return []

    detections = range(rois.shape[0])
    coco_results = []
    for coco_image_id in image_ids:
        for det in detections:
            det_class = class_ids[det]
            det_score = scores[det]
            box = np.around(rois[det], 1)
            y1, x1, y2, x2 = box[0], box[1], box[2], box[3]
            det_mask = masks[:, :, det]
            coco_results.append({
                "image_id": coco_image_id,
                "category_id": dataset.get_source_class_id(det_class, "coco"),
                "bbox": [x1, y1, x2 - x1, y2 - y1],
                "score": det_score,
                "segmentation": maskUtils.encode(np.asfortranarray(det_mask)),
            })
    return coco_results
340
341
342
def evaluate_coco(model, dataset, coco, eval_type="bbox", limit=0, image_ids=None):
343
"""Runs official COCO evaluation.
344
dataset: A Dataset object with valiadtion data
345
eval_type: "bbox" or "segm" for bounding box or segmentation evaluation
346
limit: if not 0, it's the number of images to use for evaluation
347
"""
348
# Pick COCO images from the dataset
349
image_ids = image_ids or dataset.image_ids
350
351
# Limit to a subset
352
if limit:
353
image_ids = image_ids[:limit]
354
355
# Get corresponding COCO image IDs.
356
coco_image_ids = [dataset.image_info[id]["id"] for id in image_ids]
357
358
t_prediction = 0
359
t_start = time.time()
360
361
results = []
362
for i, image_id in enumerate(image_ids):
363
# Load image
364
image = dataset.load_image(image_id)
365
366
# Run detection
367
t = time.time()
368
r = model.detect([image], verbose=0)[0]
369
t_prediction += (time.time() - t)
370
371
# Convert results to COCO format
372
# Cast masks to uint8 because COCO tools errors out on bool
373
image_results = build_coco_results(dataset, coco_image_ids[i:i + 1],
374
r["rois"], r["class_ids"],
375
r["scores"],
376
r["masks"].astype(np.uint8))
377
results.extend(image_results)
378
379
# Load results. This modifies results with additional attributes.
380
coco_results = coco.loadRes(results)
381
382
# Evaluate
383
cocoEval = COCOeval(coco, coco_results, eval_type)
384
cocoEval.params.imgIds = coco_image_ids
385
cocoEval.evaluate()
386
cocoEval.accumulate()
387
cocoEval.summarize()
388
389
print("Prediction time: {}. Average {}/image".format(
390
t_prediction, t_prediction / len(image_ids)))
391
print("Total time: ", time.time() - t_start)
392
393
394
############################################################
395
# Training
396
############################################################
397
398
399
if __name__ == '__main__':
    import argparse

    def str2bool(value):
        """Parse a boolean command-line value such as "True" or "False".

        argparse's ``type=bool`` would turn any non-empty string, including
        "False", into True; this parser accepts the usual spellings instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "Expected a boolean (True/False), got {!r}".format(value))

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='Train Mask R-CNN on MS COCO.')
    parser.add_argument("command",
                        metavar="<command>",
                        help="'train' or 'evaluate' on MS COCO")
    parser.add_argument('--dataset', required=True,
                        metavar="/path/to/coco/",
                        help='Directory of the MS-COCO dataset')
    parser.add_argument('--year', required=False,
                        default=DEFAULT_DATASET_YEAR,
                        metavar="<year>",
                        help='Year of the MS-COCO dataset (2014 or 2017) (default=2014)')
    parser.add_argument('--model', required=True,
                        metavar="/path/to/weights.h5",
                        help="Path to weights .h5 file or 'coco', 'last' or 'imagenet'")
    parser.add_argument('--logs', required=False,
                        default=DEFAULT_LOGS_DIR,
                        metavar="/path/to/logs/",
                        help='Logs and checkpoints directory (default=logs/)')
    parser.add_argument('--limit', required=False,
                        default=500,
                        type=int,
                        metavar="<image count>",
                        help='Images to use for evaluation (default=500)')
    parser.add_argument('--download', required=False,
                        default=False,
                        metavar="<True|False>",
                        help='Automatically download and unzip MS-COCO files (default=False)',
                        type=str2bool)
    args = parser.parse_args()
    print("Command: ", args.command)
    print("Model: ", args.model)
    print("Dataset: ", args.dataset)
    print("Year: ", args.year)
    print("Logs: ", args.logs)
    print("Auto Download: ", args.download)

    # Reject unknown commands before building the model and loading weights.
    if args.command not in ("train", "evaluate"):
        print("'{}' is not recognized. "
              "Use 'train' or 'evaluate'".format(args.command))
        sys.exit(1)

    # Configurations
    if args.command == "train":
        config = CocoConfig()
    else:
        class InferenceConfig(CocoConfig):
            # Set batch size to 1 since we'll be running inference on
            # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
            GPU_COUNT = 1
            IMAGES_PER_GPU = 1
            DETECTION_MIN_CONFIDENCE = 0
        config = InferenceConfig()
    config.display()

    # Create model
    mode = "training" if args.command == "train" else "inference"
    model = modellib.MaskRCNN(mode=mode, config=config,
                              model_dir=args.logs)

    # Select weights file to load
    model_choice = args.model.lower()
    if model_choice == "coco":
        model_path = COCO_MODEL_PATH
    elif model_choice == "last":
        # Find last trained weights
        model_path = model.find_last()
    elif model_choice == "imagenet":
        # Start from ImageNet trained weights
        model_path = model.get_imagenet_weights()
    else:
        model_path = args.model

    # Load weights
    print("Loading weights ", model_path)
    model.load_weights(model_path, by_name=True)

    # Train or evaluate
    if args.command == "train":
        # Training dataset. Use the training set and, for 2014, the 35K
        # valminusminival images, as in the Mask RCNN paper.
        dataset_train = CocoDataset()
        dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
        if args.year == '2014':
            dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
        dataset_train.prepare()

        # Validation dataset
        dataset_val = CocoDataset()
        val_type = "val" if args.year == '2017' else "minival"
        dataset_val.load_coco(args.dataset, val_type, year=args.year, auto_download=args.download)
        dataset_val.prepare()

        # Image Augmentation
        # Right/Left flip 50% of the time
        augmentation = imgaug.augmenters.Fliplr(0.5)

        # *** This training schedule is an example. Update to your needs ***

        # Training - Stage 1
        print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='heads',
                    augmentation=augmentation)

        # Training - Stage 2
        # Finetune layers from ResNet stage 4 and up
        print("Fine tune Resnet stage 4 and up")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=120,
                    layers='4+',
                    augmentation=augmentation)

        # Training - Stage 3
        # Fine tune all layers
        print("Fine tune all layers")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE / 10,
                    epochs=160,
                    layers='all',
                    augmentation=augmentation)

    else:
        # evaluate: Validation dataset
        dataset_val = CocoDataset()
        val_type = "val" if args.year == '2017' else "minival"
        coco = dataset_val.load_coco(args.dataset, val_type, year=args.year, return_coco=True, auto_download=args.download)
        dataset_val.prepare()
        print("Running COCO evaluation on {} images.".format(args.limit))
        evaluate_coco(model, dataset_val, coco, "bbox", limit=int(args.limit))
535
536