"""
Title: Model interpretability with Integrated Gradients
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/06/02
Last modified: 2020/06/02
Description: How to obtain integrated gradients for a classification model.
Accelerator: None
"""

"""
## Integrated Gradients

[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for
attributing a classification model's prediction to its input features. It is
a model interpretability technique: you can use it to visualize the relationship
between input features and model predictions.

Integrated Gradients is a variation on computing the gradient of the prediction
output with regard to features of the input. Rather than looking only at the
gradient at the input itself, it averages the gradients along a straight-line path
from a baseline input to the actual input, and scales that average by the
difference between the input and the baseline.
To compute integrated gradients, we need to perform the following steps:

1. Identify the input and the output. In our case, the input is an image and the
output is the last layer of our model (dense layer with softmax activation).

2. Choose a baseline input. Integrated Gradients identifies which features are
important to a neural network when it makes a prediction on a particular data point
by comparing that data point against a baseline. A baseline input can be a black
image (all pixel values set to zero) or random noise. The shape of the baseline
input needs to be the same as our input image, e.g. (299, 299, 3).

3. Interpolate between the baseline and the input for a given number of steps. The
number of steps is the number of points used to approximate the integral of the
gradients for a given input image. It is a hyperparameter; the authors recommend
using anywhere between 20 and 1000 steps.

4. Preprocess these interpolated images and do a forward pass.
5. Get the gradients for these interpolated images.
6. Approximate the integral of the gradients using the trapezoidal rule.

For an in-depth discussion of integrated gradients and why this method works,
see this excellent [article](https://distill.pub/2020/attribution-baselines/).

**References:**

- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
"""

"""
## Setup
"""


import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from IPython.display import Image, display

import tensorflow as tf
import keras
from keras import layers
from keras.applications import xception


# Size of the input image
img_size = (299, 299, 3)

# Load Xception model with imagenet weights
model = xception.Xception(weights="imagenet")

# Download the target image and get its local path
img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
display(Image(img_path))

"""
## Integrated Gradients algorithm
"""


def get_img_array(img_path, size=(299, 299)):
    # `img` is a PIL image of size 299x299
    img = keras.utils.load_img(img_path, target_size=size)
    # `array` is a float32 Numpy array of shape (299, 299, 3)
    array = keras.utils.img_to_array(img)
    # We add a dimension to transform our array into a "batch"
    # of size (1, 299, 299, 3)
    array = np.expand_dims(array, axis=0)
    return array


def get_gradients(img_input, top_pred_idx):
    """Computes the gradients of outputs w.r.t input image.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Predicted label for the input image

    Returns:
        Gradients of the predictions w.r.t img_input
    """
    images = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        tape.watch(images)
        preds = model(images)
        top_class = preds[:, top_pred_idx]

    grads = tape.gradient(top_class, images)
    return grads


def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Computes Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): The baseline image to start with for interpolation
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. The
            number of steps determines the integral approximation error. By
            default, num_steps is set to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    # If baseline is not provided, start with a black image
    # having same size as the input image.
    if baseline is None:
        baseline = np.zeros(img_size).astype(np.float32)
    else:
        baseline = baseline.astype(np.float32)

    # 1. Do interpolation.
    img_input = img_input.astype(np.float32)
    interpolated_image = [
        baseline + (step / num_steps) * (img_input - baseline)
        for step in range(num_steps + 1)
    ]
    interpolated_image = np.array(interpolated_image).astype(np.float32)

    # 2. Preprocess the interpolated images
    interpolated_image = xception.preprocess_input(interpolated_image)

    # 3. Get the gradients, one interpolated image at a time
    grads = []
    for img in interpolated_image:
        img = tf.expand_dims(img, axis=0)
        grad = get_gradients(img, top_pred_idx=top_pred_idx)
        grads.append(grad[0])
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the integral using the trapezoidal rule
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = tf.reduce_mean(grads, axis=0)

    # 5. Calculate integrated gradients and return.
    # Note: the gradients above are taken w.r.t. the preprocessed images
    # (Xception rescales pixels with x / 127.5 - 1), while `img_input - baseline`
    # is in raw pixel units, so the result is a constant multiple (127.5x)
    # of the attributions w.r.t. raw pixel values.
    integrated_grads = (img_input - baseline) * avg_grads
    return integrated_grads


def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Computes integrated gradients averaged over several random baselines.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. The
            number of steps determines the integral approximation error. By
            default, num_steps is set to 50.
        num_runs: number of random baseline images to generate

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # 1. List to keep track of Integrated Gradients (IG) for all the baselines
    integrated_grads = []

    # 2. Get the integrated gradients for all the baselines
    for _ in range(num_runs):
        # Uniform noise in [0, 255), the same pixel range as the input image
        baseline = np.random.random(img_size) * 255
        igrads = get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=baseline,
            num_steps=num_steps,
        )
        integrated_grads.append(igrads)

    # 3. Return the average integrated gradients for the image
    integrated_grads = tf.convert_to_tensor(integrated_grads)
    return tf.reduce_mean(integrated_grads, axis=0)


"""
## Helper class for visualizing gradients and integrated gradients
"""


class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image."""

    def __init__(self, positive_channel=None, negative_channel=None):
        if positive_channel is None:
            self.positive_channel = [0, 255, 0]
        else:
            self.positive_channel = positive_channel

        if negative_channel is None:
            self.negative_channel = [255, 0, 0]
        else:
            self.negative_channel = negative_channel

    def apply_polarity(self, attributions, polarity):
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        else:
            return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
            m - e
        ) + lower_end

        # 3. Make sure that the sign of the transformed attributions is the same
        # as that of the original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are at least lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
        return transformed_attributions

    def get_thresholded_attributions(self, attributions, percentage):
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the attributions
        flatten_attr = attributions.flatten()

        # 2. Get the sum of the absolute attributions (to match the sorting below)
        total = np.sum(np.abs(flatten_attr))

        # 3. Sort the absolute attributions from largest to smallest.
        sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]

        # 4. Calculate the percentage of the total sum that each attribution
        # and the values above it contribute.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 5. Threshold the attributions by the percentage
        indices_to_consider = np.where(cum_sum >= percentage)[0][0]

        # 6. Select the desired attributions and return
        attributions = sorted_attributions[indices_to_consider]
        return attributions

    def binarize(self, attributions, threshold=0.001):
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened

    def draw_outlines(
        self,
        attributions,
        percentage=90,
        connected_component_structure=np.ones((3, 3)),
    ):
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        # Nothing to outline if no pixel survived binarization
        if num_comp == 0:
            return np.zeros_like(attributions)

        # 4. Sum up the attributions for each component
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sum = np.sum(attributions[mask])
            component_sums.append((component_sum, mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        sorted_sums = list(zip(*sorted_sums_and_masks))[0]
        cumulative_sorted_sums = np.cumsum(sorted_sums)
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        # Keep at most the top three components
        if cutoff_idx > 2:
            cutoff_idx = 2

        # 6. Set the values for the kept components
        border_mask = np.zeros_like(attributions)
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                "Allowed polarity values: 'positive' or 'negative', "
                f"but provided {polarity}"
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity
        if polarity == "positive":
            attributions = self.apply_polarity(attributions, polarity=polarity)
            channel = self.positive_channel
        else:
            attributions = self.apply_polarity(attributions, polarity=polarity)
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )
        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7. Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        # 1. Make two copies of the original image
        img1 = np.copy(image)
        img2 = np.copy(image)

        # 2. Process the normal gradients
        grads_attr = self.process_grads(
            image=img1,
            attributions=gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # 3. Process the integrated gradients
        igrads_attr = self.process_grads(
            image=img2,
            attributions=integrated_gradients,
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()


"""
## Let's test-drive it
"""

# 1. Convert the image to numpy array
img = get_img_array(img_path)

# 2. Keep a copy of the original image (`preprocess_input` may modify `img` in place)
orig_img = np.copy(img[0]).astype(np.uint8)

# 3. Preprocess the image
img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)

# 4. Get model predictions
preds = model.predict(img_processed)
top_pred_idx = tf.argmax(preds[0])
print("Predicted:", int(top_pred_idx), xception.decode_predictions(preds, top=1)[0])

# 5. Get the plain gradients of the predicted class score w.r.t. the input image
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)

# 6. Get the integrated gradients
igrads = random_baseline_integrated_gradients(
    np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
)

# 7. Process the gradients and plot
vis = GradVisualizer()
vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=99,
    clip_below_percentile=0,
)

vis.visualize(
    image=orig_img,
    gradients=grads[0].numpy(),
    integrated_gradients=igrads.numpy(),
    clip_above_percentile=95,
    clip_below_percentile=28,
    morphological_cleanup=True,
    outlines=True,
)
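
"""
The six steps listed at the top of this example can be checked on a toy problem
without TensorFlow. The sketch below is a minimal NumPy version that assumes a
hypothetical scalar function `f` with a hand-written gradient `grad_f` in place of
the Xception model and `tf.GradientTape`. Like `get_integrated_gradients`, it
interpolates from the baseline to the input, evaluates the gradient at every point
on the path, averages neighboring gradients (trapezoidal rule), and scales the
result by the input-baseline difference.

```python
import numpy as np


def f(x):
    # Hypothetical differentiable "model" with a scalar output.
    return x[0] ** 2 + 3.0 * x[1] * x[2]


def grad_f(x):
    # Hand-written gradient of f; plays the role of tape.gradient(...).
    return np.array([2.0 * x[0], 3.0 * x[2], 3.0 * x[1]])


def integrated_gradients(x, baseline, num_steps=50):
    # Step 3: num_steps + 1 points from the baseline to the input.
    path = [
        baseline + (k / num_steps) * (x - baseline) for k in range(num_steps + 1)
    ]
    # Step 5: gradient at every interpolated point.
    grads = np.array([grad_f(p) for p in path])
    # Step 6: trapezoidal rule, then average over the num_steps segments.
    avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(axis=0)
    # Scale by the difference between the input and the baseline.
    return (x - baseline) * avg_grads


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)  # the "black image" baseline
attributions = integrated_gradients(x, baseline)
print(attributions, attributions.sum(), f(x) - f(baseline))
```

Because the gradient of this `f` changes linearly along the path, the trapezoidal
rule is exact here: the attributions are `[1, 9, 9]` and sum to
`f(x) - f(baseline) = 19`, up to floating-point error. This completeness property is
a convenient sanity check; with a real network and a finite number of steps it
holds only approximately.
"""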
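
"""
The docstring of `get_integrated_gradients` notes that `num_steps` controls the
integral approximation error. A small sketch of that effect, assuming a stand-in
gradient curve `cos(alpha)` along the path (whose exact integral over [0, 1] is
`sin(1)`), compares the trapezoidal average used in that function with a plain
left Riemann sum for several step counts. Since the interval has length 1, the mean
of the per-segment averages equals the trapezoidal estimate of the integral.

```python
import numpy as np


def trapezoid_average(values):
    # Mean of (g[k] + g[k + 1]) / 2 over the segments, as in
    # get_integrated_gradients; equals the trapezoidal integral over [0, 1].
    return np.mean((values[:-1] + values[1:]) / 2.0)


def left_riemann_average(values):
    # Plain left Riemann sum over [0, 1], for comparison.
    return np.mean(values[:-1])


exact = np.sin(1.0)  # integral of cos(alpha) for alpha in [0, 1]
errors = {}
for num_steps in (5, 20, 50, 300):
    alphas = np.linspace(0.0, 1.0, num_steps + 1)
    g = np.cos(alphas)  # stand-in for one gradient component along the path
    errors[num_steps] = (
        abs(trapezoid_average(g) - exact),
        abs(left_riemann_average(g) - exact),
    )
    print(num_steps, errors[num_steps])
```

On a smooth curve like this one, the trapezoidal error shrinks roughly with the
square of the step size while the left Riemann error shrinks only linearly, which
is why averaging neighboring gradients is a cheap improvement.
"""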
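
"""
`random_baseline_integrated_gradients` averages the attributions obtained from
several uniform-noise baselines. Under toy assumptions (a hypothetical
`f(x) = sum(x**2)` with an analytic gradient instead of the Xception model), the
sketch below shows what that average explains: each run sums to `f(x) - f(baseline)`
for its own baseline, so the averaged attributions sum to `f(x)` minus the mean
baseline score.

```python
import numpy as np


def f(x):
    # Hypothetical scalar "model" standing in for the class score.
    return float(np.sum(x**2))


def grad_f(x):
    # Analytic gradient of f, applied row-wise to a batch of points.
    return 2.0 * x


def integrated_gradients(x, baseline, num_steps=50):
    alphas = np.linspace(0.0, 1.0, num_steps + 1)[:, None]
    grads = grad_f(baseline + alphas * (x - baseline))  # (num_steps + 1, d)
    avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(axis=0)
    return (x - baseline) * avg_grads


rng = np.random.default_rng(0)
x = rng.random(4) * 255
# Uniform-noise baselines in [0, 255), one per run (num_runs = 2).
baselines = [rng.random(4) * 255 for _ in range(2)]
runs = [integrated_gradients(x, b) for b in baselines]
avg_attr = np.mean(runs, axis=0)

for b, r in zip(baselines, runs):
    print(r.sum(), f(x) - f(b))
print(avg_attr.sum(), f(x) - np.mean([f(b) for b in baselines]))
```

In other words, random baselines do not remove the dependence on the baseline; they
average it out, so the attributions explain the score relative to a typical noise
image rather than relative to a black image.
"""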
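
"""
`GradVisualizer.get_thresholded_attributions` and `apply_linear_transformation`
turn raw attributions into display intensities. The sketch below is a simplified
NumPy re-implementation on a small hand-made array (hypothetical data), assuming
non-negative inputs, as they are after `apply_polarity` and channel averaging in
`process_grads`. The threshold is the value at which the sorted absolute
attributions first account for the requested share of the total, and the linear
map sends threshold `m` to 1 and threshold `e` to `lower_end`.

```python
import numpy as np


def threshold_for_percentage(attributions, percentage):
    # Largest-first cumulative share of the absolute attribution mass;
    # return the value at which that share first reaches `percentage`.
    if percentage == 100.0:
        return np.min(np.abs(attributions))
    sorted_attr = np.sort(np.abs(attributions.flatten()))[::-1]
    cum_share = 100.0 * np.cumsum(sorted_attr) / np.sum(sorted_attr)
    return sorted_attr[np.where(cum_share >= percentage)[0][0]]


def linear_transform(
    attributions, clip_above_percentile, clip_below_percentile, lower_end
):
    m = threshold_for_percentage(attributions, 100 - clip_above_percentile)
    e = threshold_for_percentage(attributions, 100 - clip_below_percentile)
    # f(x) = a*x + b with f(m) = 1 and f(e) = lower_end
    out = (1 - lower_end) * (np.abs(attributions) - e) / (m - e) + lower_end
    out = out * np.sign(attributions)  # restore the sign
    out = out * (out >= lower_end)  # drop values below lower_end
    return np.clip(out, 0.0, 1.0)


attr = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 10.0]])
print(threshold_for_percentage(attr, 50.0))
print(linear_transform(attr, 90.0, 0.0, 0.0))
```

Here the largest value alone holds 50% of the total mass, so both the 50% threshold
and `m` (for `clip_above_percentile=90`) equal 10, while `clip_below_percentile=0`
makes `e` the minimum, 0; the result is simply `attr / 10`. If `m` equals `e` the
division is undefined, and the tutorial's method does not guard against that case
either.
"""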
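
"""
`GradVisualizer.draw_outlines` binarizes the attribution map, labels connected
regions, keeps the largest ones, and erodes them so that only a border remains. A
reduced sketch with `scipy.ndimage` on a synthetic 10x10 map (hypothetical data)
keeps only the single largest component, whereas the tutorial keeps up to three
components covering `percentage` percent of the foreground pixels. Like the
tutorial, it ranks components by their binarized size rather than by raw
attribution values.

```python
import numpy as np
from scipy import ndimage

# Synthetic attribution map with two separate blobs (hypothetical data).
attr = np.zeros((10, 10))
attr[1:5, 1:5] = 0.9  # 4x4 blob
attr[6:9, 6:9] = 0.2  # 3x3 blob

# 1. Binarize and fill holes, as draw_outlines does.
mask = ndimage.binary_fill_holes(attr > 0.001)

# 2. Label connected components using 8-connectivity (3x3 structure).
labels, num_comp = ndimage.label(mask, structure=np.ones((3, 3)))

# 3. Rank components by their pixel count and keep the largest one.
sizes = ndimage.sum(mask.astype(float), labels, index=np.arange(1, num_comp + 1))
keep = labels == (int(np.argmax(sizes)) + 1)

# 4. Erode once and subtract, leaving only the one-pixel border.
border = keep & ~ndimage.binary_erosion(keep, iterations=1)

print(num_comp, int(keep.sum()), int(border.sum()))
```

Erosion with the default cross-shaped structure removes the outer ring of the 4x4
blob's complement, so the 2x2 interior is subtracted and 12 border pixels remain.
"""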