📚 The CoCalc Library - books, templates and other resources
License: OTHER
from __future__ import division

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import grey_dilation

from skimage import img_as_float
from skimage import color
from skimage import exposure
from skimage.util.dtype import dtype_limits


__all__ = ['imshow_all', 'imshow_with_histogram', 'mean_filter_demo',
           'mean_filter_interactive_demo', 'plot_cdf', 'plot_histogram']


# Gray-scale images should actually be gray!
plt.rcParams['image.cmap'] = 'gray'


#--------------------------------------------------------------------------
# Custom `imshow` functions
#--------------------------------------------------------------------------

def imshow_rgb_shifted(rgb_image, shift=100, ax=None):
    """Plot each RGB layer with an x, y shift.

    Each channel is drawn as its own semi-transparent image (all other
    channels zeroed), offset by `shift` from the previously drawn channel.

    Parameters
    ----------
    rgb_image : array
        Image of shape ``(height, width, n_channels)``.
    shift : int
        Offset, in data coordinates, between consecutive channel layers.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()

    height, width, n_channels = rgb_image.shape
    x = y = 0
    for i_channel, channel in enumerate(iter_channels(rgb_image)):
        image = np.zeros((height, width, n_channels), dtype=channel.dtype)

        image[:, :, i_channel] = channel
        ax.imshow(image, extent=[x, x+width, y, y+height], alpha=0.7)
        x += shift
        y += shift
    # `imshow` fits the extents of the last image shown, so we need to rescale.
    ax.autoscale()
    ax.set_axis_off()


def imshow_all(*images, **kwargs):
    """ Plot a series of images side-by-side.

    Convert all images to float so that images have a common intensity range.

    Parameters
    ----------
    limits : str
        Control the intensity limits. By default, 'image' is used to set the
        min/max intensities to the min/max of all images. Setting `limits` to
        'dtype' can also be used if you want to preserve the image exposure.
    titles : list of str
        Titles for subplots. If the length of titles is less than the number
        of images, empty strings are appended.
    shape : tuple of int
        ``(nrows, ncols)`` layout of the subplot grid. Defaults to a single
        row with one column per image.
    size : float
        Base size (in inches) used to compute the figure size. Defaults to 5.
    kwargs : dict
        Additional keyword-arguments passed to `imshow`.
    """
    images = [img_as_float(img) for img in images]

    titles = kwargs.pop('titles', [])
    if len(titles) != len(images):
        titles = list(titles) + [''] * (len(images) - len(titles))

    limits = kwargs.pop('limits', 'image')
    if limits == 'image':
        kwargs.setdefault('vmin', min(img.min() for img in images))
        kwargs.setdefault('vmax', max(img.max() for img in images))
    elif limits == 'dtype':
        vmin, vmax = dtype_limits(images[0])
        kwargs.setdefault('vmin', vmin)
        kwargs.setdefault('vmax', vmax)

    # `shape` only describes the subplot grid; pop it so it is not forwarded
    # to `imshow`, which does not accept a `shape` keyword in current
    # matplotlib.
    nrows, ncols = kwargs.pop('shape', (1, len(images)))

    size = nrows * kwargs.pop('size', 5)
    width = size * len(images)
    if nrows > 1:
        width /= nrows * 1.33
    # `squeeze=False` keeps `axes` a 2D array even for a 1x1 grid, so
    # `.ravel()` also works when a single image is passed.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, size),
                             squeeze=False)
    for ax, img, label in zip(axes.ravel(), images, titles):
        ax.imshow(img, **kwargs)
        ax.set_title(label)


def imshow_with_histogram(image, **kwargs):
    """ Plot an image side-by-side with its histogram.

    - Plot the image next to the histogram
    - Plot each RGB channel separately (if input is color)
    - Automatically flatten channels
    - Select reasonable bins based on the image's dtype

    See `plot_histogram` for information on how the histogram is plotted.

    Returns
    -------
    ax_image, ax_hist : matplotlib Axes
        The axes holding the image and the histogram, respectively.
    """
    width, height = plt.rcParams['figure.figsize']
    fig, (ax_image, ax_hist) = plt.subplots(ncols=2, figsize=(2*width, height))

    kwargs.setdefault('cmap', plt.cm.gray)
    ax_image.imshow(image, **kwargs)
    plot_histogram(image, ax=ax_hist)

    # pretty it up
    ax_image.set_axis_off()
    match_axes_height(ax_image, ax_hist)
    return ax_image, ax_hist


#--------------------------------------------------------------------------
# Helper functions
#--------------------------------------------------------------------------


def match_axes_height(ax_src, ax_dst):
    """ Match the axes height of two axes objects.

    The height of `ax_dst` is synced to that of `ax_src`.
    """
    # HACK: plot geometry isn't set until the plot is drawn
    plt.draw()
    dst = ax_dst.get_position()
    src = ax_src.get_position()
    ax_dst.set_position([dst.xmin, src.ymin, dst.width, src.height])


def plot_cdf(image, ax=None):
    """ Plot the cumulative distribution of `image` intensities on `ax`.

    Uses `skimage.exposure.cumulative_distribution` and draws it as a red
    line. If `ax` is None, the current axes is used.
    """
    ax = ax if ax is not None else plt.gca()
    img_cdf, bins = exposure.cumulative_distribution(image)
    ax.plot(bins, img_cdf, 'r')
    ax.set_ylabel("Fraction of pixels below intensity")


def plot_histogram(image, ax=None, **kwargs):
    """ Plot the histogram of an image (gray-scale or RGB) on `ax`.

    Calculate histogram using `skimage.exposure.histogram` and plot as filled
    line. If an image has a 3rd dimension, assume it's RGB and plot each
    channel separately. Images with any other number of dimensions are
    not plotted.

    Parameters
    ----------
    image : array
        2D gray-scale or 3D color image.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes.
    kwargs : dict
        Extra keyword-arguments forwarded to `Axes.fill_between`.
    """
    ax = ax if ax is not None else plt.gca()

    if image.ndim == 2:
        _plot_histogram(ax, image, color='black', **kwargs)
    elif image.ndim == 3:
        # `channel` is the red, green, or blue channel of the image.
        for channel, channel_color in zip(iter_channels(image), 'rgb'):
            _plot_histogram(ax, channel, color=channel_color, **kwargs)


def _plot_histogram(ax, image, alpha=0.3, **kwargs):
    """ Draw the histogram of a single-channel `image` as a filled area."""
    # Use skimage's histogram function which has nice defaults for
    # integer and float images.
    hist, bin_centers = exposure.histogram(image)
    ax.fill_between(bin_centers, hist, alpha=alpha, **kwargs)
    ax.set_xlabel('intensity')
    ax.set_ylabel('# pixels')


def iter_channels(color_image):
    """Yield color channels of an image."""
    # Roll array-axis so that we iterate over the color channels of an image.
    for channel in np.rollaxis(color_image, -1):
        yield channel


#--------------------------------------------------------------------------
# Convolution Demo
#--------------------------------------------------------------------------

def mean_filter_demo(image, vmax=1):
    """ Return a function that displays a 3x3 mean filter applied step by step.

    The returned ``mean_filter_step(i_step)`` shows, via `imshow_all`, the
    kernel overlay for pixel number `i_step` next to a copy of `image` in
    which pixels ``0..i_step`` (in row-major order) have been replaced by the
    filtered value. Results are cached, so stepping forward only computes
    the new pixels.

    Each filtered value is ``sum(neighborhood) / 9``, where the neighborhood
    comes from the original `image`. At the border the neighborhood is
    clipped, so this is equivalent to a mean with zero padding.
    """
    mean_factor = 1.0 / 9.0  # This assumes a 3x3 kernel.
    iter_kernel_and_subimage = iter_kernel(image)

    # image_cache[k] is (kernel overlay, filtered image) after step k.
    image_cache = []

    def mean_filter_step(i_step):
        while i_step >= len(image_cache):
            # Start from the original image when nothing is cached yet;
            # otherwise continue from the latest filtered result. Checking
            # the cache (not `i_step == 0`) avoids an IndexError when the
            # first requested step is greater than zero.
            filtered = image if not image_cache else image_cache[-1][1]
            filtered = filtered.copy()

            # The `next()` builtin works on Python 2 and 3; the generator
            # `.next()` method exists only on Python 2.
            (i, j), mask, subimage = next(iter_kernel_and_subimage)
            filter_overlay = color.label2rgb(mask, image, bg_label=0,
                                             colors=('yellow', 'red'))
            filtered[i, j] = np.sum(mean_factor * subimage)
            image_cache.append((filter_overlay, filtered))

        imshow_all(*image_cache[i_step], vmax=vmax)
        plt.show()
    return mean_filter_step


def mean_filter_interactive_demo(image):
    """ Drive `mean_filter_demo` with an IPython integer slider over steps."""
    # NOTE(review): `IPython.html.widgets` is the legacy home of notebook
    # widgets; newer environments may need the standalone widgets package.
    from IPython.html import widgets
    mean_filter_step = mean_filter_demo(image)
    step_slider = widgets.IntSliderWidget(min=0, max=image.size-1, value=0)
    widgets.interact(mean_filter_step, i_step=step_slider)


def iter_kernel(image, size=1):
    """ Yield position, kernel mask, and image for each pixel in the image.

    The kernel mask has a 2 at the center pixel and 1 around it. The actual
    width of the kernel is 2*size + 1.

    Yields
    ------
    (i, j) : tuple of int
        Row and column of the center pixel.
    mask : int16 array
        Array of ``image.shape`` marking the kernel footprint.
    subimage : array
        The neighborhood of ``(i, j)`` in `image`, clipped at the borders.
    """
    width = 2*size + 1
    for (i, j), _ in iter_pixels(image):
        mask = np.zeros(image.shape, dtype='int16')
        mask[i, j] = 1
        # Grow the single marked pixel to the full kernel footprint.
        mask = grey_dilation(mask, size=width)
        mask[i, j] = 2
        # NumPy requires a tuple (not a list) of slices for
        # multidimensional indexing.
        window = tuple(bounded_slice((i, j), image.shape[:2], size=size))
        subimage = image[window]
        yield (i, j), mask, subimage


def iter_pixels(image):
    """ Yield pixel position (row, column) and pixel intensity. """
    height, width = image.shape[:2]
    for i in range(height):
        for j in range(width):
            yield (i, j), image[i, j]


def bounded_slice(center, xy_max, size=1, i_min=0):
    """ Return a list of slices spanning ``center +/- size`` on each axis.

    Each slice is clipped to ``[i_min, i_max)`` where `i_max` comes from
    `xy_max`. Convert the result with ``tuple(...)`` before using it to
    index a NumPy array.
    """
    slices = []
    for i, i_max in zip(center, xy_max):
        slices.append(slice(max(i - size, i_min), min(i + size + 1, i_max)))
    return slices