"""1`Introduction <introyt1_tutorial.html>`_ ||2`Tensors <tensors_deeper_tutorial.html>`_ ||3`Autograd <autogradyt_tutorial.html>`_ ||4`Building Models <modelsyt_tutorial.html>`_ ||5`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||6`Training Models <trainingyt.html>`_ ||7**Model Understanding**89Model Understanding with Captum10===============================1112Follow along with the video below or on `youtube <https://www.youtube.com/watch?v=Am2EF9CLu-g>`__. Download the notebook and corresponding files13`here <https://pytorch-tutorial-assets.s3.amazonaws.com/youtube-series/video7.zip>`__.1415.. raw:: html1617<div style="margin-top:10px; margin-bottom:10px;">18<iframe width="560" height="315" src="https://www.youtube.com/embed/Am2EF9CLu-g" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>19</div>2021`Captum <https://captum.ai/>`__ (“comprehension” in Latin) is an open22source, extensible library for model interpretability built on PyTorch.2324With the increase in model complexity and the resulting lack of25transparency, model interpretability methods have become increasingly26important. Model understanding is both an active area of research as27well as an area of focus for practical applications across industries28using machine learning. Captum provides state-of-the-art algorithms,29including Integrated Gradients, to provide researchers and developers30with an easy way to understand which features are contributing to a31model’s output.3233Full documentation, an API reference, and a suite of tutorials on34specific topics are available at the `captum.ai <https://captum.ai/>`__35website.3637Introduction38------------3940Captum’s approach to model interpretability is in terms of41*attributions.* There are three kinds of attributions available in42Captum:4344- **Feature Attribution** seeks to explain a particular output in terms45of features of the input that generated it. Explaining whether a46movie review was positive or negative in terms of certain words in47the review is an example of feature attribution.48- **Layer Attribution** examines the activity of a model’s hidden layer49subsequent to a particular input. Examining the spatially-mapped50output of a convolutional layer in response to an input image in an51example of layer attribution.52- **Neuron Attribution** is analagous to layer attribution, but focuses53on the activity of a single neuron.5455In this interactive notebook, we’ll look at Feature Attribution and56Layer Attribution.5758Each of the three attribution types has multiple **attribution59algorithms** associated with it. Many attribution algorithms fall into60two broad categories:6162- **Gradient-based algorithms** calculate the backward gradients of a63model output, layer output, or neuron activation with respect to the64input. **Integrated Gradients** (for features), **Layer Gradient \*65Activation**, and **Neuron Conductance** are all gradient-based66algorithms.67- **Perturbation-based algorithms** examine the changes in the output68of a model, layer, or neuron in response to changes in the input. The69input perturbations may be directed or random. **Occlusion,**70**Feature Ablation,** and **Feature Permutation** are all71perturbation-based algorithms.7273We’ll be examining algorithms of both types below.7475Especially where large models are involved, it can be valuable to76visualize attribution data in ways that relate it easily to the input77features being examined. 
Especially where large models are involved, it can be valuable to
visualize attribution data in ways that relate it easily to the input
features being examined. While it is certainly possible to create your
own visualizations with Matplotlib, Plotly, or similar tools, Captum
offers enhanced tools specific to its attributions:

-  The ``captum.attr.visualization`` module (imported below as ``viz``)
   provides helpful functions for visualizing attributions related to
   images.
-  **Captum Insights** is an easy-to-use API on top of Captum that
   provides a visualization widget with ready-made visualizations for
   image, text, and arbitrary model types.

Both of these visualization toolsets will be demonstrated in this
notebook. The first few examples will focus on computer vision use
cases, but the Captum Insights section at the end will demonstrate
visualization of attributions in a multi-modal, visual
question-and-answer model.

Installation
------------

Before you get started, you need to have a Python environment with:

-  Python version 3.6 or higher
-  For the Captum Insights example, Flask 1.1 or higher and Flask-Compress
   (the latest version is recommended)
-  PyTorch version 1.2 or higher (the latest version is recommended)
-  TorchVision version 0.6 or higher (the latest version is recommended)
-  Captum (the latest version is recommended)
-  Matplotlib version 3.3.4, since Captum currently uses a Matplotlib
   function whose arguments have been renamed in later versions

To install Captum in an Anaconda or pip virtual environment, use the
appropriate command for your environment below:

With ``conda``:

.. code-block:: sh

   conda install pytorch torchvision captum flask-compress matplotlib=3.3.4 -c pytorch

With ``pip``:

.. code-block:: sh

   pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress

Restart this notebook in the environment you set up, and you’re ready to
go!


A First Example
---------------

To start, let’s take a simple, visual example. We’ll start with a ResNet
model pretrained on the ImageNet dataset. We’ll get a test input, and
use different **Feature Attribution** algorithms to examine how the
input images affect the output, and see a helpful visualization of this
input attribution map for some test images.

First, some imports:

"""

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models

import captum
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam, LayerAttribution
from captum.attr import visualization as viz

import os, sys
import json

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
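
#########################################################################
# As a quick, optional sanity check, you can confirm that the versions
# installed in your environment satisfy the requirements listed above -
# in particular the pinned Matplotlib version.
#

import matplotlib
import torchvision

print('torch:', torch.__version__)
print('torchvision:', torchvision.__version__)
print('captum:', captum.__version__)
print('matplotlib:', matplotlib.__version__)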

#########################################################################
# Now we’ll use the TorchVision model library to download a pretrained
# ResNet. Since we’re not training, we’ll place it in evaluation mode for
# now.
#

model = models.resnet18(weights='IMAGENET1K_V1')
model = model.eval()


#######################################################################
# The place where you got this interactive notebook should also have an
# ``img`` folder with a file ``cat.jpg`` in it.
#

test_img = Image.open('img/cat.jpg')
test_img_data = np.asarray(test_img)
plt.imshow(test_img_data)
plt.show()


##########################################################################
# Our ResNet model was trained on the ImageNet dataset, and expects images
# to be of a certain size, with the channel data normalized to a specific
# range of values. We’ll also pull in the list of human-readable labels
# for the categories our model recognizes - that should be in the ``img``
# folder as well.
#

# model expects 224x224 3-color image
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

# standard ImageNet normalization
transform_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

transformed_img = transform(test_img)
input_img = transform_normalize(transformed_img)
input_img = input_img.unsqueeze(0)  # the model requires a dummy batch dimension

labels_path = 'img/imagenet_class_index.json'
with open(labels_path) as json_data:
    idx_to_labels = json.load(json_data)


######################################################################
# Now, we can ask the question: What does our model think this image
# represents?
#

output = model(input_img)
output = F.softmax(output, dim=1)
prediction_score, pred_label_idx = torch.topk(output, 1)
pred_label_idx.squeeze_()
predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')


######################################################################
# We’ve confirmed that ResNet thinks our image of a cat is, in fact, a
# cat. But *why* does the model think this is an image of a cat?
#
# For the answer to that, we turn to Captum.
#


##########################################################################
# Feature Attribution with Integrated Gradients
# ---------------------------------------------
#
# **Feature attribution** attributes a particular output to features of
# the input. It uses a specific input - here, our test image - to generate
# a map of the relative importance of each input feature to a particular
# output feature.
#
# `Integrated
# Gradients <https://captum.ai/api/integrated_gradients.html>`__ is one of
# the feature attribution algorithms available in Captum. Integrated
# Gradients assigns an importance score to each input feature by
# approximating the integral of the gradients of the model’s output with
# respect to the inputs.
#
# In our case, we’re going to be taking a specific element of the output
# vector - that is, the one indicating the model’s confidence in its
# chosen category - and use Integrated Gradients to understand what parts
# of the input image contributed to this output.
#
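
##########################################################################
# To build intuition for what Integrated Gradients is doing, here is a
# rough, hand-rolled approximation of the same idea (for illustration
# only - the Captum call in the next cell is the supported API): a small
# Riemann sum of the gradients taken along the straight-line path from an
# all-black baseline image to our actual input.
#

baseline = torch.zeros_like(input_img)  # all-black baseline image
n_steps = 20
total_grads = torch.zeros_like(input_img)

for k in range(1, n_steps + 1):
    # interpolate between the baseline and the real input
    scaled = baseline + (k / n_steps) * (input_img - baseline)
    scaled.requires_grad_()
    score = model(scaled)[0, pred_label_idx]
    total_grads += torch.autograd.grad(score, scaled)[0]

# scale the averaged gradients by the difference from the baseline
manual_ig = (input_img - baseline) * total_grads / n_steps
print(manual_ig.shape)  # one attribution value per input pixel and channel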

##########################################################################
# Once we have the importance map from Integrated Gradients, we’ll use the
# visualization tools in Captum to give a helpful representation of the
# importance map. Captum’s ``visualize_image_attr()`` function provides a
# variety of options for customizing the display of your attribution data.
# Here, we pass in a custom Matplotlib color map.
#
# Running the cell with the ``integrated_gradients.attribute()`` call will
# usually take a minute or two.
#

# Initialize the attribution algorithm with the model
integrated_gradients = IntegratedGradients(model)

# Ask the algorithm to attribute our output target to the input image
attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200)

# Show the original image for comparison
_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method="original_image", title="Original Image")

default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                 [(0, '#ffffff'),
                                                  (0.25, '#0000ff'),
                                                  (1, '#0000ff')], N=256)

_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                             np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             title='Integrated Gradients')


#######################################################################
# In the image above, you should see that Integrated Gradients gives us
# the strongest signal around the cat’s location in the image.
#


##########################################################################
# Feature Attribution with Occlusion
# ----------------------------------
#
# Gradient-based attribution methods help to understand the model in terms
# of directly computing the gradients of the output with respect to the
# input. *Perturbation-based attribution* methods approach this more
# directly, by introducing changes to the input to measure the effect on
# the output. `Occlusion <https://captum.ai/api/occlusion.html>`__ is one
# such method. It involves replacing sections of the input image, and
# examining the effect on the output signal.
#
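
##########################################################################
# To see the core measurement behind Occlusion in isolation, here is a
# quick manual version of a single step (illustration only - the patch
# location below is arbitrary, and Captum’s ``Occlusion`` in the next
# cells scans the whole image for us): cover one patch of the image with
# zeros and see how much the target score drops.
#

occluded = input_img.clone()
occluded[:, :, 100:115, 100:115] = 0  # zero out one (arbitrary) 15x15 patch

with torch.no_grad():
    orig_score = F.softmax(model(input_img), dim=1)[0, pred_label_idx]
    occ_score = F.softmax(model(occluded), dim=1)[0, pred_label_idx]

print('Score drop from occluding this patch:', (orig_score - occ_score).item())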

##########################################################################
# Below, we set up Occlusion attribution. Similarly to configuring a
# convolutional neural network, you can specify the size of the target
# region, and a stride length to determine the spacing of individual
# measurements. We’ll visualize the output of our Occlusion attribution
# with ``visualize_image_attr_multiple()``, showing heat maps of both
# positive and negative attribution by region, and by masking the original
# image with the positive attribution regions. The masking gives a very
# instructive view of what regions of our cat photo the model found to be
# most “cat-like”.
#

occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(input_img,
                                       target=pred_label_idx,
                                       strides=(3, 8, 8),
                                       sliding_window_shapes=(3, 15, 15),
                                       baselines=0)


_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map", "heat_map", "masked_image"],
                                      ["all", "positive", "negative", "positive"],
                                      show_colorbar=True,
                                      titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
                                      fig_size=(18, 6)
                                     )


######################################################################
# Again, we see greater significance placed on the region of the image
# that contains the cat.
#


#########################################################################
# Layer Attribution with Layer GradCAM
# ------------------------------------
#
# **Layer Attribution** allows you to attribute the activity of hidden
# layers within your model to features of your input. Below, we’ll use a
# layer attribution algorithm to examine the activity of one of the
# convolutional layers within our model.
#
# GradCAM computes the gradients of the target output with respect to the
# given layer, averages for each output channel (dimension 2 of output),
# and multiplies the average gradient for each channel by the layer
# activations. The results are summed over all channels. GradCAM is
# designed for convnets; since the activity of convolutional layers often
# maps spatially to the input, GradCAM attributions are often upsampled
# and used to mask the input.
#
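
##########################################################################
# As a sketch of that computation (illustration only - Captum’s
# ``LayerGradCam`` in the next cell is the supported API), here it is done
# by hand for one convolutional layer, using a forward hook to capture the
# layer’s activations. We use the same layer we’ll pass to Captum below.
#

acts = {}

def save_activation(module, inp, out):
    acts['value'] = out

hook = model.layer3[1].conv2.register_forward_hook(save_activation)
score = model(input_img)[0, pred_label_idx]
hook.remove()

layer_acts = acts['value']                                    # shape (1, C, H, W)
layer_grads = torch.autograd.grad(score, layer_acts)[0]       # gradients of the target score
channel_weights = layer_grads.mean(dim=(2, 3), keepdim=True)  # average gradient per channel
# weight the activations by the averaged gradients and sum over channels
# (classic GradCAM additionally applies a ReLU to keep only positive influence)
manual_gradcam = (channel_weights * layer_acts).sum(dim=1, keepdim=True)
print(manual_gradcam.shape)  # a single spatial attribution map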

##########################################################################
# Layer attribution is set up similarly to input attribution, except that
# in addition to the model, you must specify a hidden layer within the
# model that you wish to examine. As above, when we call ``attribute()``,
# we specify the target class of interest.
#

layer_gradcam = LayerGradCam(model, model.layer3[1].conv2)
attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx)

_ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute(1,2,0).detach().numpy(),
                             sign="all",
                             title="Layer 3 Block 1 Conv 2")


##########################################################################
# We’ll use the convenience method ``interpolate()`` in the
# `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__
# base class to upsample this attribution data for comparison to the input
# image.
#

upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, input_img.shape[2:])

print(attributions_lgc.shape)
print(upsamp_attr_lgc.shape)
print(input_img.shape)

_ = viz.visualize_image_attr_multiple(upsamp_attr_lgc[0].cpu().permute(1,2,0).detach().numpy(),
                                      transformed_img.permute(1,2,0).numpy(),
                                      ["original_image", "blended_heat_map", "masked_image"],
                                      ["all", "positive", "positive"],
                                      show_colorbar=True,
                                      titles=["Original", "Positive Attribution", "Masked"],
                                      fig_size=(18, 6))


#######################################################################
# Visualizations such as this can give you novel insights into how your
# hidden layers respond to your input.
#


##########################################################################
# Visualization with Captum Insights
# ----------------------------------
#
# Captum Insights is an interpretability visualization widget built on top
# of Captum to facilitate model understanding. Captum Insights works
# across images, text, and other features to help users understand feature
# attribution. It allows you to visualize attribution for multiple
# input/output pairs, and provides visualization tools for image, text,
# and arbitrary data.
#
# In this section of the notebook, we’ll visualize multiple image
# classification inferences with Captum Insights.
#
# First, let’s gather some images and see what the model thinks of them.
# For variety, we’ll take our cat, a teapot, and a trilobite fossil:
#

imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg']

for img in imgs:
    img = Image.open(img)
    transformed_img = transform(img)
    input_img = transform_normalize(transformed_img)
    input_img = input_img.unsqueeze(0)  # the model requires a dummy batch dimension

    output = model(input_img)
    output = F.softmax(output, dim=1)
    prediction_score, pred_label_idx = torch.topk(output, 1)
    pred_label_idx.squeeze_()
    predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
    print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')')


##########################################################################
# …and it looks like our model is identifying them all correctly - but of
# course, we want to dig deeper. For that we’ll use the Captum Insights
# widget, which we configure with an ``AttributionVisualizer`` object,
# imported below. The ``AttributionVisualizer`` expects batches of data,
# so we’ll bring in Captum’s ``Batch`` helper class.
# And we’ll be looking at images specifically, so we’ll also import
# ``ImageFeature``.
#
# We configure the ``AttributionVisualizer`` with the following arguments:
#
# -  An array of models to be examined (in our case, just the one)
# -  A scoring function, which allows Captum Insights to pull out the
#    top-k predictions from a model
# -  An ordered, human-readable list of classes our model is trained on
# -  A list of features to look for - in our case, an ``ImageFeature``
# -  A dataset, which is an iterable object returning batches of inputs
#    and labels - just like you’d use for training
#

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

# Baseline is all-zeros input - this may differ depending on your data
def baseline_func(input):
    return input * 0

# merging our image transforms from above
def full_img_transform(input):
    i = Image.open(input)
    i = transform(i)
    i = transform_normalize(i)
    i = i.unsqueeze(0)
    return i


input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0)

visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())),
    features=[
        ImageFeature(
            "Photo",
            baseline_transforms=[baseline_func],
            input_transforms=[],
        )
    ],
    dataset=[Batch(input_imgs, labels=[282, 849, 69])]
)


#########################################################################
# Note that running the cell above didn’t take much time at all, unlike
# our attributions above. That’s because Captum Insights lets you
# configure different attribution algorithms in a visual widget, after
# which it will compute and display the attributions. *That* process will
# take a few minutes.
#
# Running the cell below will render the Captum Insights widget. You can
# then choose attribution methods and their arguments, filter model
# responses based on predicted class or prediction correctness, see the
# model’s predictions with associated probabilities, and view heatmaps of
# the attribution compared with the original image.
#

visualizer.render()
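

#########################################################################
# If you are running this as a plain Python script rather than in a
# notebook, the widget can’t render inline. In that case, Captum Insights
# can also serve the same interface in your browser via Flask - which is
# why Flask and Flask-Compress appear in the requirements above. In the
# Captum versions we’ve seen, this is exposed as a ``serve()`` method on
# the visualizer, for example ``visualizer.serve(debug=True)``; check the
# Captum Insights documentation for your installed version if the call
# differs.
#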