GitHub Repository: matterport/Mask_RCNN
Path: blob/master/samples/nucleus/inspect_nucleus_data.ipynb
Kernel: Python 3

Inspect Nucleus Training Data

Inspect and visualize data loading and pre-processing code.

https://www.kaggle.com/c/data-science-bowl-2018

import os
import sys
import itertools
import math
import logging
import json
import re
import random
import time
import concurrent.futures
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as lines
from matplotlib.patches import Polygon
import imgaug
from imgaug import augmenters as iaa

# Root directory of the project
ROOT_DIR = os.getcwd()
if ROOT_DIR.endswith("samples/nucleus"):
    # Go up two levels to the repo root
    ROOT_DIR = os.path.dirname(os.path.dirname(ROOT_DIR))

# Import Mask RCNN
sys.path.append(ROOT_DIR)
from mrcnn import utils
from mrcnn import visualize
from mrcnn.visualize import display_images
from mrcnn import model as modellib
from mrcnn.model import log

import nucleus

%matplotlib inline
/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
# Uncomment to reload imported modules if they change
# %load_ext autoreload
# %autoreload 2

Configurations

# Dataset directory
DATASET_DIR = os.path.join(ROOT_DIR, "datasets/nucleus")

# Use the configuration from nucleus.py, but override
# image resizing so we see the real sizes here
class NoResizeConfig(nucleus.NucleusConfig):
    IMAGE_RESIZE_MODE = "none"

config = NoResizeConfig()

Notebook Preferences

def get_ax(rows=1, cols=1, size=16):
    """Return a Matplotlib Axes array to be used in
    all visualizations in the notebook. Provide a
    central point to control graph sizes.

    Adjust the size attribute to control how big to render images.
    """
    _, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
    return ax

Dataset

Download the dataset from the competition website. Unzip it and save it in mask_rcnn/datasets/nucleus. If you prefer a different directory, change the DATASET_DIR variable above.

https://www.kaggle.com/c/data-science-bowl-2018/data

# Load dataset
dataset = nucleus.NucleusDataset()
# The subset is the name of the sub-directory, such as stage1_train,
# stage1_test, ...etc. You can also use these special values:
#     train: loads stage1_train but excludes validation images
#     val: loads validation images from stage1_train. For a list
#          of validation images see nucleus.py
dataset.load_nucleus(DATASET_DIR, subset="train")

# Must call before using the dataset
dataset.prepare()

print("Image Count: {}".format(len(dataset.image_ids)))
print("Class Count: {}".format(dataset.num_classes))
for i, info in enumerate(dataset.class_info):
    print("{:3}. {:50}".format(i, info['name']))
Image Count: 639
Class Count: 2
  0. BG
  1. nucleus

Display Samples

# Load and display random samples
image_ids = np.random.choice(dataset.image_ids, 4)
for image_id in image_ids:
    image = dataset.load_image(image_id)
    mask, class_ids = dataset.load_mask(image_id)
    visualize.display_top_masks(image, mask, class_ids, dataset.class_names, limit=1)
Image in a Jupyter notebook (4 images)
# Example of loading a specific image by its source ID
source_id = "ed5be4b63e9506ad64660dd92a098ffcc0325195298c13c815a73773f1efc279"

# Map source ID to Dataset image_id
# Notice the nucleus prefix: it's the name given to the dataset in NucleusDataset
image_id = dataset.image_from_source_map["nucleus.{}".format(source_id)]

# Load and display
image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(
    dataset, config, image_id, use_mini_mask=False)
log("molded_image", image)
log("mask", mask)
visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names,
                            show_bbox=False)
molded_image  shape: (256, 320, 3)   min: 28.00000  max: 232.00000  uint8
mask          shape: (256, 320, 42)  min:  0.00000  max:   1.00000  bool
Image in a Jupyter notebook

Dataset Stats

Loop through all images in the dataset and collect aggregate stats.

def image_stats(image_id):
    """Returns a dict of stats for one image."""
    image = dataset.load_image(image_id)
    mask, _ = dataset.load_mask(image_id)
    bbox = utils.extract_bboxes(mask)
    # Sanity check
    assert mask.shape[:2] == image.shape[:2]
    # Return stats dict
    return {
        "id": image_id,
        "shape": list(image.shape),
        "bbox": [[b[2] - b[0], b[3] - b[1]]
                 for b in bbox
                 # Uncomment to exclude nuclei with 1 pixel width
                 # or height (often on edges)
                 # if b[2] - b[0] > 1 and b[3] - b[1] > 1
                ],
        "color": np.mean(image, axis=(0, 1)),
    }

# Loop through the dataset and compute stats over multiple threads.
# This might take a few minutes
t_start = time.time()
with concurrent.futures.ThreadPoolExecutor() as e:
    stats = list(e.map(image_stats, dataset.image_ids))
t_total = time.time() - t_start
print("Total time: {:.1f} seconds".format(t_total))
Total time: 59.5 seconds

Image Size Stats

# Image stats
image_shape = np.array([s['shape'] for s in stats])
image_color = np.array([s['color'] for s in stats])
print("Image Count: ", image_shape.shape[0])
print("Height  mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
    np.mean(image_shape[:, 0]), np.median(image_shape[:, 0]),
    np.min(image_shape[:, 0]), np.max(image_shape[:, 0])))
print("Width   mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
    np.mean(image_shape[:, 1]), np.median(image_shape[:, 1]),
    np.min(image_shape[:, 1]), np.max(image_shape[:, 1])))
print("Color   mean (RGB): {:.2f} {:.2f} {:.2f}".format(*np.mean(image_color, axis=0)))

# Histograms
fig, ax = plt.subplots(1, 3, figsize=(16, 4))
ax[0].set_title("Height")
_ = ax[0].hist(image_shape[:, 0], bins=20)
ax[1].set_title("Width")
_ = ax[1].hist(image_shape[:, 1], bins=20)
ax[2].set_title("Height & Width")
_ = ax[2].hist2d(image_shape[:, 1], image_shape[:, 0], bins=10, cmap="Blues")
Image Count:  639
Height  mean: 332.27  median: 256.00  min: 256.00  max: 1024.00
Width   mean: 373.72  median: 256.00  min: 256.00  max: 1272.00
Color   mean (RGB): 41.55 37.83 46.13
Image in a Jupyter notebook

Nuclei per Image Stats

# Segment by image area
image_area_bins = [256**2, 600**2, 1300**2]

print("Nuclei/Image")
fig, ax = plt.subplots(1, len(image_area_bins), figsize=(16, 4))
area_threshold = 0
for i, image_area in enumerate(image_area_bins):
    nuclei_per_image = np.array([len(s['bbox'])
                                 for s in stats
                                 if area_threshold < (s['shape'][0] * s['shape'][1]) <= image_area])
    area_threshold = image_area
    if len(nuclei_per_image) == 0:
        print("Image area <= {:4}**2: None".format(np.sqrt(image_area)))
        continue
    print("Image area <= {:4.0f}**2:  mean: {:.1f}  median: {:.1f}  min: {:.1f}  max: {:.1f}".format(
        np.sqrt(image_area), nuclei_per_image.mean(), np.median(nuclei_per_image),
        nuclei_per_image.min(), nuclei_per_image.max()))
    ax[i].set_title("Image Area <= {:4}**2".format(np.sqrt(image_area)))
    _ = ax[i].hist(nuclei_per_image, bins=10)
Nuclei/Image
Image area <=  256**2:  mean: 28.7   median: 22.0   min: 1.0  max: 101.0
Image area <=  600**2:  mean: 34.5   median: 25.0   min: 1.0  max: 151.0
Image area <= 1300**2:  mean: 103.0  median: 101.0  min: 4.0  max: 375.0
Image in a Jupyter notebook

Nuclei Size Stats

# Nuclei size stats
fig, ax = plt.subplots(1, len(image_area_bins), figsize=(16, 4))
area_threshold = 0
for i, image_area in enumerate(image_area_bins):
    nucleus_shape = np.array([
        b
        for s in stats if area_threshold < (s['shape'][0] * s['shape'][1]) <= image_area
        for b in s['bbox']])
    nucleus_area = nucleus_shape[:, 0] * nucleus_shape[:, 1]
    area_threshold = image_area

    print("\nImage Area <= {:.0f}**2".format(np.sqrt(image_area)))
    print("  Total Nuclei: ", nucleus_shape.shape[0])
    print("  Nucleus Height. mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
        np.mean(nucleus_shape[:, 0]), np.median(nucleus_shape[:, 0]),
        np.min(nucleus_shape[:, 0]), np.max(nucleus_shape[:, 0])))
    print("  Nucleus Width.  mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
        np.mean(nucleus_shape[:, 1]), np.median(nucleus_shape[:, 1]),
        np.min(nucleus_shape[:, 1]), np.max(nucleus_shape[:, 1])))
    print("  Nucleus Area.   mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
        np.mean(nucleus_area), np.median(nucleus_area),
        np.min(nucleus_area), np.max(nucleus_area)))

    # Show 2D histogram
    _ = ax[i].hist2d(nucleus_shape[:, 1], nucleus_shape[:, 0], bins=20, cmap="Blues")
Image Area <= 256**2
  Total Nuclei:  9383
  Nucleus Height. mean: 14.14  median: 13.00  min: 2.00  max: 60.00
  Nucleus Width.  mean: 14.37  median: 13.00  min: 2.00  max: 61.00
  Nucleus Area.   mean: 231.92  median: 168.00  min: 24.00  max: 2880.00

Image Area <= 600**2
  Total Nuclei:  7078
  Nucleus Height. mean: 28.65  median: 23.00  min: 2.00  max: 120.00
  Nucleus Width.  mean: 29.74  median: 24.00  min: 1.00  max: 117.00
  Nucleus Area.   mean: 1094.10  median: 546.00  min: 21.00  max: 13560.00

Image Area <= 1300**2
  Total Nuclei:  11023
  Nucleus Height. mean: 25.89  median: 26.00  min: 2.00  max: 91.00
  Nucleus Width.  mean: 25.91  median: 26.00  min: 3.00  max: 88.00
  Nucleus Area.   mean: 751.73  median: 693.00  min: 21.00  max: 5120.00
Image in a Jupyter notebook
# Nuclei height/width ratio
nucleus_aspect_ratio = nucleus_shape[:, 0] / nucleus_shape[:, 1]
print("Nucleus Aspect Ratio.  mean: {:.2f}  median: {:.2f}  min: {:.2f}  max: {:.2f}".format(
    np.mean(nucleus_aspect_ratio), np.median(nucleus_aspect_ratio),
    np.min(nucleus_aspect_ratio), np.max(nucleus_aspect_ratio)))
plt.figure(figsize=(15, 5))
_ = plt.hist(nucleus_aspect_ratio, bins=100, range=[0, 5])
Nucleus Aspect Ratio. mean: 1.08 median: 1.00 min: 0.10 max: 8.17
Image in a Jupyter notebook

Image Augmentation

Test out different augmentation methods

# List of augmentations
# http://imgaug.readthedocs.io/en/latest/source/augmenters.html
augmentation = iaa.Sometimes(0.9, [
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    iaa.Multiply((0.8, 1.2)),
    iaa.GaussianBlur(sigma=(0.0, 5.0))
])
# Load the image multiple times to show augmentations
limit = 4
ax = get_ax(rows=2, cols=limit//2)
for i in range(limit):
    image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(
        dataset, config, image_id, use_mini_mask=False,
        augment=False, augmentation=augmentation)
    visualize.display_instances(image, bbox, mask, class_ids,
                                dataset.class_names, ax=ax[i//2, i % 2],
                                show_mask=False, show_bbox=False)
Image in a Jupyter notebook

Image Crops

Microscopy images tend to be large, but nuclei are small. So it's more efficient to train on random crops from the large images. This is handled by config.IMAGE_RESIZE_MODE = "crop".
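
To make the idea concrete, here is a minimal sketch of what crop-mode resizing does conceptually. This is illustrative only, not the library's code path (the real logic lives in utils.resize_image with mode="crop"), and the helper name random_crop is hypothetical.

import numpy as np

def random_crop(image, mask, crop_size=256):
    """Take a random crop_size x crop_size window from an image and
    the matching slice of its instance masks.
    image: [H, W, 3]. mask: [H, W, num_instances].
    """
    h, w = image.shape[:2]
    assert h >= crop_size and w >= crop_size, "image smaller than crop"
    y = np.random.randint(0, h - crop_size + 1)
    x = np.random.randint(0, w - crop_size + 1)
    crop_image = image[y:y + crop_size, x:x + crop_size]
    crop_mask = mask[y:y + crop_size, x:x + crop_size]
    # Drop instances that fall entirely outside the crop window
    keep = np.where(crop_mask.any(axis=(0, 1)))[0]
    return crop_image, crop_mask[..., keep], keep

The keep indices would be used to filter class_ids the same way, so masks and class IDs stay aligned.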

class RandomCropConfig(nucleus.NucleusConfig):
    IMAGE_RESIZE_MODE = "crop"
    IMAGE_MIN_DIM = 256
    IMAGE_MAX_DIM = 256

crop_config = RandomCropConfig()
# Load the image multiple times to show random crops
limit = 4
image_id = np.random.choice(dataset.image_ids, 1)[0]
ax = get_ax(rows=2, cols=limit//2)
for i in range(limit):
    image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(
        dataset, crop_config, image_id, use_mini_mask=False)
    visualize.display_instances(image, bbox, mask, class_ids,
                                dataset.class_names, ax=ax[i//2, i % 2],
                                show_mask=False, show_bbox=False)
Image in a Jupyter notebook

Mini Masks

Instance binary masks can get large when training with high resolution images. For example, when training on a 1024x1024 image, the mask of a single instance requires 1MB of memory (NumPy uses one byte per boolean value). If an image has 100 instances, that's 100MB for the masks alone.

To improve training speed, we optimize masks:

  • We store mask pixels that are inside the object bounding box, rather than a mask of the full image. Most objects are small compared to the image size, so we save space by not storing a lot of zeros around the object.

  • We resize the mask to a smaller size (e.g. 56x56). For objects that are larger than the selected size we lose a bit of accuracy. But most object annotations are not very accurate to begin with, so this loss is negligible for most practical purposes. The size of the mini mask can be set in the config class. A rough sketch of this crop-and-resize step follows the list.
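
Here is a rough sketch of the crop-and-resize step and the memory math. The actual implementation is utils.minimize_mask() and utils.expand_mask() in mrcnn/utils.py; the helper to_mini_mask below is a hypothetical simplification.

import numpy as np
import skimage.transform

def to_mini_mask(mask, bbox, mini_shape=(56, 56)):
    """Crop each instance mask to its bounding box, then resize to mini_shape.
    mask: [H, W, num_instances] bool. bbox: [num_instances, (y1, x1, y2, x2)].
    """
    mini = np.zeros(mini_shape + (mask.shape[-1],), dtype=bool)
    for i in range(mask.shape[-1]):
        y1, x1, y2, x2 = bbox[i]
        if y2 > y1 and x2 > x1:  # skip degenerate boxes
            m = mask[y1:y2, x1:x2, i].astype(float)
            mini[:, :, i] = skimage.transform.resize(
                m, mini_shape, order=1, mode="constant") >= 0.5
    return mini

# Memory math from the paragraph above: a bool mask of a 1024x1024 image
# takes 1024 * 1024 = 1,048,576 bytes (~1MB) per instance, while a 56x56
# mini mask takes 56 * 56 = 3,136 bytes -- roughly 330x smaller.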

To see the effect of mask resizing, and to verify the correctness of the code, we visualize some examples.

# Load random image and mask.
image_id = np.random.choice(dataset.image_ids, 1)[0]
image = dataset.load_image(image_id)
mask, class_ids = dataset.load_mask(image_id)
original_shape = image.shape
# Resize
image, window, scale, padding, _ = utils.resize_image(
    image,
    min_dim=config.IMAGE_MIN_DIM,
    max_dim=config.IMAGE_MAX_DIM,
    mode=config.IMAGE_RESIZE_MODE)
mask = utils.resize_mask(mask, scale, padding)
# Compute Bounding box
bbox = utils.extract_bboxes(mask)

# Display image and additional stats
print("image_id: ", image_id, dataset.image_reference(image_id))
print("Original shape: ", original_shape)
log("image", image)
log("mask", mask)
log("class_ids", class_ids)
log("bbox", bbox)
# Display image and instances
visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)
image_id:  245 3b957237bc1e09740b58a414282393d3a91dde996b061e7061f4198fb03dab2e
Original shape:  (256, 256, 3)
image      shape: (256, 256, 3)   min: 9.00000  max: 255.00000  uint8
mask       shape: (256, 256, 26)  min: 0.00000  max:   1.00000  bool
class_ids  shape: (26,)           min: 1.00000  max:   1.00000  int32
bbox       shape: (26, 4)         min: 0.00000  max: 256.00000  int32
Image in a Jupyter notebook
image_id = np.random.choice(dataset.image_ids, 1)[0]
image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(
    dataset, config, image_id, use_mini_mask=False)

log("image", image)
log("image_meta", image_meta)
log("class_ids", class_ids)
log("bbox", bbox)
log("mask", mask)

display_images([image] + [mask[:, :, i] for i in range(min(mask.shape[-1], 7))])
image       shape: (360, 360, 3)   min: 3.00000  max:  50.00000  uint8
image_meta  shape: (14,)           min: 0.00000  max: 399.00000  int64
class_ids   shape: (22,)           min: 1.00000  max:   1.00000  int32
bbox        shape: (22, 4)         min: 0.00000  max: 360.00000  int32
mask        shape: (360, 360, 22)  min: 0.00000  max:   1.00000  bool
Image in a Jupyter notebook
visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)
Image in a Jupyter notebook
# Add augmentation and mask resizing.
image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(
    dataset, config, image_id, augment=True, use_mini_mask=True)
log("mask", mask)
display_images([image] + [mask[:, :, i] for i in range(min(mask.shape[-1], 7))])
WARNING:root:'augment' is depricated. Use 'augmentation' instead.
mask shape: (56, 56, 22) min: 0.00000 max: 1.00000 bool
Image in a Jupyter notebook
mask = utils.expand_mask(bbox, mask, image.shape)
visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names)
Image in a Jupyter notebook

Anchors

For an FPN network, the anchors must be ordered in a way that makes it easy to match anchors to the output of the convolution layers that predict anchor scores and shifts.

  • Sort by pyramid level first. All anchors of the first level, then all of the second and so on. This makes it easier to separate anchors by level.

  • Within each level, sort anchors by feature map processing sequence. Typically, a convolution layer processes a feature map starting from top-left and moving right row by row.

  • For each feature map cell, pick any sorting order for the anchors of different ratios. Here we match the order of ratios passed to the function. The sketch after this list shows the resulting flat indexing.
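
To make the ordering concrete, here is a small sketch (not library code) that computes where an anchor lands in the flat anchors array, assuming RPN_ANCHOR_STRIDE == 1 as in the config used in this notebook. The helper name flat_anchor_index is hypothetical.

def flat_anchor_index(level, row, col, ratio_ix, backbone_shapes, anchors_per_cell=3):
    # Rule 1: all anchors of earlier pyramid levels come first
    index = sum(int(s[0]) * int(s[1]) * anchors_per_cell
                for s in backbone_shapes[:level])
    # Rule 2: within a level, cells are ordered row-major
    # (top-left, moving right, then down)
    index += (row * int(backbone_shapes[level][1]) + col) * anchors_per_cell
    # Rule 3: within a cell, anchors follow the order of the ratios list
    return index + ratio_ix

# Example: the first anchor of level 1 comes right after the
# 64 * 64 * 3 = 12288 anchors of level 0, matching the per-level
# counts printed by the cell below.
# flat_anchor_index(1, 0, 0, 0, backbone_shapes)  # -> 12288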

## Visualize anchors of one cell at the center of the feature map

# Load and display random image
image_id = np.random.choice(dataset.image_ids, 1)[0]
image, image_meta, _, _, _ = modellib.load_image_gt(dataset, crop_config, image_id)

# Generate Anchors
backbone_shapes = modellib.compute_backbone_shapes(config, image.shape)
anchors = utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES,
                                         config.RPN_ANCHOR_RATIOS,
                                         backbone_shapes,
                                         config.BACKBONE_STRIDES,
                                         config.RPN_ANCHOR_STRIDE)

# Print summary of anchors
num_levels = len(backbone_shapes)
anchors_per_cell = len(config.RPN_ANCHOR_RATIOS)
print("Count: ", anchors.shape[0])
print("Scales: ", config.RPN_ANCHOR_SCALES)
print("ratios: ", config.RPN_ANCHOR_RATIOS)
print("Anchors per Cell: ", anchors_per_cell)
print("Levels: ", num_levels)
anchors_per_level = []
for l in range(num_levels):
    num_cells = backbone_shapes[l][0] * backbone_shapes[l][1]
    anchors_per_level.append(anchors_per_cell * num_cells // config.RPN_ANCHOR_STRIDE**2)
    print("Anchors in Level {}: {}".format(l, anchors_per_level[l]))

# Display
fig, ax = plt.subplots(1, figsize=(10, 10))
ax.imshow(image)
levels = len(backbone_shapes)

for level in range(levels):
    colors = visualize.random_colors(levels)
    # Compute the index of the anchors at the center of the image
    level_start = sum(anchors_per_level[:level])  # sum of anchors of previous levels
    level_anchors = anchors[level_start:level_start + anchors_per_level[level]]
    print("Level {}. Anchors: {:6}  Feature map Shape: {}".format(
        level, level_anchors.shape[0], backbone_shapes[level]))
    center_cell = backbone_shapes[level] // 2
    center_cell_index = (center_cell[0] * backbone_shapes[level][1] + center_cell[1])
    level_center = center_cell_index * anchors_per_cell
    center_anchor = anchors_per_cell * (
        (center_cell[0] * backbone_shapes[level][1] / config.RPN_ANCHOR_STRIDE**2)
        + center_cell[1] / config.RPN_ANCHOR_STRIDE)
    level_center = int(center_anchor)

    # Draw anchors. Brightness shows the order in the array, dark to bright.
    for i, rect in enumerate(level_anchors[level_center:level_center + anchors_per_cell]):
        y1, x1, y2, x2 = rect
        p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                              facecolor='none',
                              edgecolor=(i + 1) * np.array(colors[level]) / anchors_per_cell)
        ax.add_patch(p)
Count:  16368
Scales:  (8, 16, 32, 64, 128)
ratios:  [0.5, 1, 2]
Anchors per Cell:  3
Levels:  5
Anchors in Level 0: 12288
Anchors in Level 1: 3072
Anchors in Level 2: 768
Anchors in Level 3: 192
Anchors in Level 4: 48
Level 0. Anchors:  12288  Feature map Shape: [64 64]
Level 1. Anchors:   3072  Feature map Shape: [32 32]
Level 2. Anchors:    768  Feature map Shape: [16 16]
Level 3. Anchors:    192  Feature map Shape: [8 8]
Level 4. Anchors:     48  Feature map Shape: [4 4]
Image in a Jupyter notebook

Data Generator

# Create data generator
random_rois = 2000
g = modellib.data_generator(
    dataset, crop_config, shuffle=True, random_rois=random_rois,
    batch_size=4, detection_targets=True)
# Uncomment to run the generator through a lot of images
# to catch rare errors
# for i in range(1000):
#     print(i)
#     _, _ = next(g)
# Get Next Image
if random_rois:
    [normalized_images, image_meta, rpn_match, rpn_bbox, gt_class_ids, gt_boxes, gt_masks, rpn_rois, rois], \
    [mrcnn_class_ids, mrcnn_bbox, mrcnn_mask] = next(g)

    log("rois", rois)
    log("mrcnn_class_ids", mrcnn_class_ids)
    log("mrcnn_bbox", mrcnn_bbox)
    log("mrcnn_mask", mrcnn_mask)
else:
    # gt_class_ids must be unpacked here too; it's used below
    [normalized_images, image_meta, rpn_match, rpn_bbox, gt_class_ids, gt_boxes, gt_masks], _ = next(g)

log("gt_class_ids", gt_class_ids)
log("gt_boxes", gt_boxes)
log("gt_masks", gt_masks)
log("rpn_match", rpn_match)
log("rpn_bbox", rpn_bbox)
image_id = modellib.parse_image_meta(image_meta)["image_id"][0]
print("image_id: ", image_id, dataset.image_reference(image_id))

# Remove the last dim in mrcnn_class_ids. It's only added
# to satisfy Keras restriction on target shape.
mrcnn_class_ids = mrcnn_class_ids[:, :, 0]
rois             shape: (4, 128, 4)          min:  0.00000  max: 255.00000  int32
mrcnn_class_ids  shape: (4, 128, 1)          min:  0.00000  max:   1.00000  int32
mrcnn_bbox       shape: (4, 128, 2, 4)       min: -3.33333  max:   3.00000  float32
mrcnn_mask       shape: (4, 128, 28, 28, 2)  min:  0.00000  max:   1.00000  float32
gt_class_ids     shape: (4, 200)             min:  0.00000  max:   1.00000  int32
gt_boxes         shape: (4, 200, 4)          min:  0.00000  max: 256.00000  int32
gt_masks         shape: (4, 56, 56, 200)     min:  0.00000  max:   1.00000  bool
rpn_match        shape: (4, 16368, 1)        min: -1.00000  max:   1.00000  int32
rpn_bbox         shape: (4, 64, 4)           min: -6.93147  max:   5.62500  float64
image_id:  168 edd36ed822e7ed760ff73e0524df22aa5bf5c565efcdc6c39603239c0896e7a8
b = 0

# Restore original image (reverse normalization)
sample_image = modellib.unmold_image(normalized_images[b], config)

# Compute anchor shifts.
indices = np.where(rpn_match[b] == 1)[0]
refined_anchors = utils.apply_box_deltas(anchors[indices], rpn_bbox[b, :len(indices)] * config.RPN_BBOX_STD_DEV)
log("anchors", anchors)
log("refined_anchors", refined_anchors)

# Get list of positive anchors
positive_anchor_ids = np.where(rpn_match[b] == 1)[0]
print("Positive anchors: {}".format(len(positive_anchor_ids)))
negative_anchor_ids = np.where(rpn_match[b] == -1)[0]
print("Negative anchors: {}".format(len(negative_anchor_ids)))
neutral_anchor_ids = np.where(rpn_match[b] == 0)[0]
print("Neutral anchors: {}".format(len(neutral_anchor_ids)))

# ROI breakdown by class
for c, n in zip(dataset.class_names, np.bincount(mrcnn_class_ids[b].flatten())):
    if n:
        print("{:23}: {}".format(c[:20], n))

# Show positive anchors
fig, ax = plt.subplots(1, figsize=(16, 16))
visualize.draw_boxes(sample_image, boxes=anchors[positive_anchor_ids],
                     refined_boxes=refined_anchors, ax=ax)
anchors          shape: (16368, 4)  min: -90.50967  max: 282.50967  float64
refined_anchors  shape: (32, 4)     min:  -0.00000  max: 256.00000  float32
Positive anchors: 32
Negative anchors: 32
Neutral anchors: 16304
BG                     : 86
nucleus                : 42
Image in a Jupyter notebook
# Show negative anchors
visualize.draw_boxes(sample_image, boxes=anchors[negative_anchor_ids])
Image in a Jupyter notebook
# Show neutral anchors. They don't contribute to training.
visualize.draw_boxes(sample_image, boxes=anchors[np.random.choice(neutral_anchor_ids, 100)])
Image in a Jupyter notebook

ROIs

Typically, the RPN network generates region proposals (a.k.a. Regions of Interest, or ROIs). For illustration and testing purposes, the data generator can also generate proposals of its own; these are controlled by the random_rois parameter.
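
For intuition, a simplified sketch of random ROI generation is shown below. The helper uniform_random_rois is hypothetical and just draws boxes uniformly over the image; the real modellib.generate_random_rois() is more careful, e.g. it concentrates proposals around the ground-truth boxes.

import numpy as np

def uniform_random_rois(image_shape, count):
    """Return `count` random boxes as [count, (y1, x1, y2, x2)] in pixel coords."""
    h, w = image_shape[:2]
    y = np.sort(np.random.randint(0, h, (count, 2)), axis=1)
    x = np.sort(np.random.randint(0, w, (count, 2)), axis=1)
    # Note: this can produce degenerate (zero-area) boxes, which the
    # real generator avoids.
    return np.concatenate([y[:, :1], x[:, :1], y[:, 1:], x[:, 1:]], axis=1)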

if random_rois:
    # Class aware bboxes
    bbox_specific = mrcnn_bbox[b, np.arange(mrcnn_bbox.shape[1]), mrcnn_class_ids[b], :]

    # Refined ROIs
    refined_rois = utils.apply_box_deltas(rois[b].astype(np.float32), bbox_specific[:, :4] * config.BBOX_STD_DEV)

    # Class aware masks
    mask_specific = mrcnn_mask[b, np.arange(mrcnn_mask.shape[1]), :, :, mrcnn_class_ids[b]]

    visualize.draw_rois(sample_image, rois[b], refined_rois, mask_specific, mrcnn_class_ids[b], dataset.class_names)

    # Any repeated ROIs?
    rows = np.ascontiguousarray(rois[b]).view(np.dtype((np.void, rois.dtype.itemsize * rois.shape[-1])))
    _, idx = np.unique(rows, return_index=True)
    print("Unique ROIs: {} out of {}".format(len(idx), rois.shape[1]))
Positive ROIs:  42
Negative ROIs:  86
Positive Ratio: 0.33
Unique ROIs: 128 out of 128
Image in a Jupyter notebook
if random_rois:
    # Display ROIs and corresponding masks and bounding boxes
    ids = random.sample(range(rois.shape[1]), 8)

    images = []
    titles = []
    for i in ids:
        image = visualize.draw_box(sample_image.copy(), rois[b, i, :4].astype(np.int32), [255, 0, 0])
        image = visualize.draw_box(image, refined_rois[i].astype(np.int64), [0, 255, 0])
        images.append(image)
        titles.append("ROI {}".format(i))
        images.append(mask_specific[i] * 255)
        titles.append(dataset.class_names[mrcnn_class_ids[b, i]][:20])

    display_images(images, titles, cols=4, cmap="Blues", interpolation="none")
Image in a Jupyter notebook
# Check ratio of positive ROIs in a set of images.
if random_rois:
    limit = 10
    temp_g = modellib.data_generator(
        dataset, crop_config, shuffle=True, random_rois=10000,
        batch_size=1, detection_targets=True)
    total = 0
    for i in range(limit):
        _, [ids, _, _] = next(temp_g)
        positive_rois = np.sum(ids[0] > 0)
        total += positive_rois
        print("{:5} {:5.2f}".format(positive_rois, positive_rois/ids.shape[1]))
    print("Average percent: {:.2f}".format(total/(limit*ids.shape[1])))
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
   42  0.33
Average percent: 0.33