GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/introyt/captumyt.py
"""
`Introduction <introyt1_tutorial.html>`_ ||
`Tensors <tensors_deeper_tutorial.html>`_ ||
`Autograd <autogradyt_tutorial.html>`_ ||
`Building Models <modelsyt_tutorial.html>`_ ||
`TensorBoard Support <tensorboardyt_tutorial.html>`_ ||
`Training Models <trainingyt.html>`_ ||
**Model Understanding**

Model Understanding with Captum
===============================

Follow along with the video below or on `YouTube <https://www.youtube.com/watch?v=Am2EF9CLu-g>`__. Download the notebook and corresponding files
`here <https://pytorch-tutorial-assets.s3.amazonaws.com/youtube-series/video7.zip>`__.

.. raw:: html

   <div style="margin-top:10px; margin-bottom:10px;">
     <iframe width="560" height="315" src="https://www.youtube.com/embed/Am2EF9CLu-g" frameborder="0" allow="accelerometer; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
   </div>

`Captum <https://captum.ai/>`__ (“comprehension” in Latin) is an open
source, extensible library for model interpretability built on PyTorch.

With the increase in model complexity and the resulting lack of
transparency, model interpretability methods have become increasingly
important. Model understanding is both an active area of research and
an area of focus for practical applications across industries
using machine learning. Captum provides state-of-the-art algorithms,
including Integrated Gradients, to provide researchers and developers
with an easy way to understand which features are contributing to a
model’s output.

Full documentation, an API reference, and a suite of tutorials on
specific topics are available at the `captum.ai <https://captum.ai/>`__
website.

Introduction
------------

Captum’s approach to model interpretability is in terms of
*attributions.* There are three kinds of attributions available in
Captum:

- **Feature Attribution** seeks to explain a particular output in terms
  of features of the input that generated it. Explaining whether a
  movie review was positive or negative in terms of certain words in
  the review is an example of feature attribution.
- **Layer Attribution** examines the activity of a model’s hidden layer
  subsequent to a particular input. Examining the spatially-mapped
  output of a convolutional layer in response to an input image is an
  example of layer attribution.
- **Neuron Attribution** is analogous to layer attribution, but focuses
  on the activity of a single neuron.

In this interactive notebook, we’ll look at Feature Attribution and
Layer Attribution.

Each of the three attribution types has multiple **attribution
algorithms** associated with it. Many attribution algorithms fall into
two broad categories:

- **Gradient-based algorithms** calculate the backward gradients of a
  model output, layer output, or neuron activation with respect to the
  input. **Integrated Gradients** (for features), **Layer Gradient \*
  Activation**, and **Neuron Conductance** are all gradient-based
  algorithms.
- **Perturbation-based algorithms** examine the changes in the output
  of a model, layer, or neuron in response to changes in the input. The
  input perturbations may be directed or random. **Occlusion,**
  **Feature Ablation,** and **Feature Permutation** are all
  perturbation-based algorithms.

We’ll be examining algorithms of both types below.

Especially where large models are involved, it can be valuable to
visualize attribution data in ways that relate it easily to the input
features being examined. While it is certainly possible to create your
own visualizations with Matplotlib, Plotly, or similar tools, Captum
offers enhanced tools specific to its attributions:

- The ``captum.attr.visualization`` module (imported below as ``viz``)
  provides helpful functions for visualizing attributions related to
  images.
- **Captum Insights** is an easy-to-use API on top of Captum that
  provides a visualization widget with ready-made visualizations for
  image, text, and arbitrary model types.

Both of these visualization toolsets will be demonstrated in this
notebook. The first few examples compute attributions for a single
image, and the Captum Insights section at the end visualizes
attributions for several image classification inferences.

Installation
------------

Before you get started, you need to have a Python environment with:

- Python version 3.6 or higher
- For the Captum Insights example, Flask 1.1 or higher and Flask-Compress
  (the latest version is recommended)
- PyTorch version 1.2 or higher (the latest version is recommended)
- TorchVision version 0.6 or higher (the latest version is recommended)
- Captum (the latest version is recommended)
- Matplotlib version 3.3.4, since Captum currently uses a Matplotlib
  function whose arguments have been renamed in later versions

To install Captum in an Anaconda or pip virtual environment, use the
appropriate command for your environment below:

With ``conda``:

.. code-block:: sh

    conda install pytorch torchvision captum flask-compress matplotlib=3.3.4 -c pytorch

With ``pip``:

.. code-block:: sh

    pip install torch torchvision captum matplotlib==3.3.4 Flask-Compress

Restart this notebook in the environment you set up, and you’re ready to
go!


A First Example
---------------

To start, let’s take a simple, visual example. We’ll start with a ResNet
model pretrained on the ImageNet dataset. We’ll get a test input, and
use different **Feature Attribution** algorithms to examine how the
input images affect the output, and see a helpful visualization of this
input attribution map for some test images.

First, some imports:

"""

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models

import captum
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam, LayerAttribution
from captum.attr import visualization as viz

import os, sys
import json

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

#########################################################################
# Now we’ll use the TorchVision model library to download a pretrained
# ResNet. Since we’re not training, we’ll place it in evaluation mode for
# now.
#

model = models.resnet18(weights='IMAGENET1K_V1')
model = model.eval()


#######################################################################
# The place where you got this interactive notebook should also have an
# ``img`` folder with a file ``cat.jpg`` in it.
#

test_img = Image.open('img/cat.jpg')
test_img_data = np.asarray(test_img)
plt.imshow(test_img_data)
plt.show()

##########################################################################
# Our ResNet model was trained on the ImageNet dataset, and expects images
# to be of a certain size, with the channel data normalized to a specific
# range of values. We’ll also pull in the list of human-readable labels
# for the categories our model recognizes - that should be in the ``img``
# folder as well.
#

# model expects 224x224 3-color image
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

# standard ImageNet normalization
transform_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

transformed_img = transform(test_img)
input_img = transform_normalize(transformed_img)
input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension

labels_path = 'img/imagenet_class_index.json'
with open(labels_path) as json_data:
    idx_to_labels = json.load(json_data)
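
######################################################################
# As a rough illustration of what the normalization step above does,
# the sketch below applies the same per-channel ``(x - mean) / std``
# arithmetic by hand, using plain tensor broadcasting. The names
# ``toy_pixels``, ``toy_mean``, ``toy_std``, and ``toy_normalized`` are
# made up for this example. This is not TorchVision's implementation,
# only the same formula written out.
#

import torch

# A fake 3-channel, 2x2 "image" with every value at 0.5; ToTensor()
# produces values in the range [0, 1], so 0.5 is a mid-gray pixel
toy_pixels = torch.full((3, 2, 2), 0.5)

# The ImageNet statistics used above, reshaped to (3, 1, 1) so they
# broadcast over the height and width dimensions
toy_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
toy_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

toy_normalized = (toy_pixels - toy_mean) / toy_std
print(toy_normalized[:, 0, 0])  # one normalized value per channel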

######################################################################
# Now, we can ask the question: What does our model think this image
# represents?
#

output = model(input_img)
output = F.softmax(output, dim=1)
prediction_score, pred_label_idx = torch.topk(output, 1)
pred_label_idx.squeeze_()
predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
print('Predicted:', predicted_label, '(', prediction_score.squeeze().item(), ')')


######################################################################
# We’ve confirmed that ResNet thinks our image of a cat is, in fact, a
# cat. But *why* does the model think this is an image of a cat?
#
# For the answer to that, we turn to Captum.
#

##########################################################################
# Feature Attribution with Integrated Gradients
# ---------------------------------------------
#
# **Feature attribution** attributes a particular output to features of
# the input. It uses a specific input - here, our test image - to generate
# a map of the relative importance of each input feature to a particular
# output feature.
#
# `Integrated
# Gradients <https://captum.ai/api/integrated_gradients.html>`__ is one of
# the feature attribution algorithms available in Captum. Integrated
# Gradients assigns an importance score to each input feature by
# approximating the integral of the gradients of the model’s output with
# respect to the inputs.
#
# In our case, we’re going to be taking a specific element of the output
# vector - that is, the one indicating the model’s confidence in its
# chosen category - and use Integrated Gradients to understand what parts
# of the input image contributed to this output.
#
# Once we have the importance map from Integrated Gradients, we’ll use the
# visualization tools in Captum to give a helpful representation of the
# importance map. Captum’s ``visualize_image_attr()`` function provides a
# variety of options for customizing display of your attribution data.
# Here, we pass in a custom Matplotlib color map.
#
# Running the cell with the ``integrated_gradients.attribute()`` call will
# usually take a minute or two.
#

# Initialize the attribution algorithm with the model
integrated_gradients = IntegratedGradients(model)

# Ask the algorithm to attribute our output target to the input features
attributions_ig = integrated_gradients.attribute(input_img, target=pred_label_idx, n_steps=200)

# Show the original image for comparison
_ = viz.visualize_image_attr(None, np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method="original_image", title="Original Image")

default_cmap = LinearSegmentedColormap.from_list('custom blue',
                                                 [(0, '#ffffff'),
                                                  (0.25, '#0000ff'),
                                                  (1, '#0000ff')], N=256)

_ = viz.visualize_image_attr(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
                             np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                             method='heat_map',
                             cmap=default_cmap,
                             show_colorbar=True,
                             sign='positive',
                             title='Integrated Gradients')


#######################################################################
# In the image above, you should see that Integrated Gradients gives us
# the strongest signal around the cat’s location in the image.
#
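
######################################################################
# To make the idea of "approximating the integral of the gradients"
# from this section more concrete, here is a minimal sketch of
# Integrated Gradients on a tiny hand-written function. It is an
# illustration under stated assumptions, not Captum's implementation:
# it uses a simple midpoint Riemann sum along the straight path from an
# all-zeros baseline to the input, and ``toy_integrated_gradients``,
# ``toy_fn``, and the other ``toy_`` names are hypothetical helpers
# invented for this example.
#

import torch


def toy_integrated_gradients(fn, x, baseline, n_steps=50):
    # Midpoints of n_steps equal sub-intervals of [0, 1]
    alphas = (torch.arange(n_steps, dtype=x.dtype) + 0.5) / n_steps
    total_grad = torch.zeros_like(x)
    for alpha in alphas:
        # A point on the straight line between the baseline and the input
        point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
        grad, = torch.autograd.grad(fn(point), point)
        total_grad += grad
    # Average gradient along the path, scaled by the input difference
    return (x - baseline) * total_grad / n_steps


def toy_fn(v):
    # A small scalar-valued "model": quadratic terms plus one linear term
    return (v ** 2).sum() + 3.0 * v[0]


toy_x = torch.tensor([1.0, -2.0, 0.5])
toy_baseline = torch.zeros_like(toy_x)
toy_attr = toy_integrated_gradients(toy_fn, toy_x, toy_baseline)

# Difference between the summed attributions and the change in output
toy_gap = (toy_attr.sum() - (toy_fn(toy_x) - toy_fn(toy_baseline))).abs().item()
print(toy_attr, toy_gap)

######################################################################
# The attributions should sum to roughly ``toy_fn(toy_x) -
# toy_fn(toy_baseline)``, a property of Integrated Gradients often
# called completeness, so ``toy_gap`` should be close to zero.
#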

##########################################################################
# Feature Attribution with Occlusion
# ----------------------------------
#
# Gradient-based attribution methods help us understand the model by
# directly computing how the output changes with respect to the input.
# *Perturbation-based attribution* methods approach this more directly, by
# introducing changes to the input to measure the effect on the output.
# `Occlusion <https://captum.ai/api/occlusion.html>`__ is one such method.
# It involves replacing sections of the input image, and examining the
# effect on the output signal.
#
# Below, we set up Occlusion attribution. Similarly to configuring a
# convolutional neural network, you can specify the size of the target
# region, and a stride length to determine the spacing of individual
# measurements. We’ll visualize the output of our Occlusion attribution
# with ``visualize_image_attr_multiple()``, showing heat maps of both
# positive and negative attribution by region, and by masking the original
# image with the positive attribution regions. The masking gives a very
# instructive view of what regions of our cat photo the model found to be
# most “cat-like”.
#

occlusion = Occlusion(model)

attributions_occ = occlusion.attribute(input_img,
                                       target=pred_label_idx,
                                       strides=(3, 8, 8),
                                       sliding_window_shapes=(3, 15, 15),
                                       baselines=0)


_ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map", "heat_map", "masked_image"],
                                      ["all", "positive", "negative", "positive"],
                                      show_colorbar=True,
                                      titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
                                      fig_size=(18, 6)
                                     )


######################################################################
# Again, we see greater significance placed on the region of the image
# that contains the cat.
#
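
######################################################################
# The sliding-window idea behind occlusion can be sketched in a few
# lines. The code below is a simplified, hypothetical version for a
# single 2D input, not Captum's implementation: ``toy_occlusion``,
# ``toy_score``, and ``toy_img`` are names invented for this example.
# Each window is replaced with a baseline value, the drop in the score
# is recorded, and drops are averaged wherever windows overlap.
#

import torch


def toy_occlusion(fn, x, window, stride, baseline=0.0):
    base_out = fn(x)
    total = torch.zeros_like(x)
    counts = torch.zeros_like(x)
    height, width = x.shape
    for top in range(0, height - window + 1, stride):
        for left in range(0, width - window + 1, stride):
            # Replace one window of the input with the baseline value
            occluded = x.clone()
            occluded[top:top + window, left:left + window] = baseline
            # How much did hiding this window reduce the score?
            drop = base_out - fn(occluded)
            total[top:top + window, left:left + window] += drop
            counts[top:top + window, left:left + window] += 1
    # Average the recorded drops over every window covering each pixel
    return total / counts.clamp(min=1)


def toy_score(img):
    # A stand-in "model" that only looks at the top-left 2x2 corner
    return img[:2, :2].sum()


toy_img = torch.ones(4, 4)
toy_occ_attr = toy_occlusion(toy_score, toy_img, window=2, stride=2)
print(toy_occ_attr)

######################################################################
# Only the window covering the top-left corner changes the score, so
# that is the only region that receives a nonzero attribution. The
# ``strides`` and ``sliding_window_shapes`` arguments used with Captum
# above play the same role as ``stride`` and ``window`` here.
#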

#########################################################################
# Layer Attribution with Layer GradCAM
# ------------------------------------
#
# **Layer Attribution** allows you to attribute the activity of hidden
# layers within your model to features of your input. Below, we’ll use a
# layer attribution algorithm to examine the activity of one of the
# convolutional layers within our model.
#
# GradCAM computes the gradients of the target output with respect to the
# given layer, averages for each output channel (dimension 2 of output),
# and multiplies the average gradient for each channel by the layer
# activations. The results are summed over all channels. GradCAM is
# designed for convnets; since the activity of convolutional layers often
# maps spatially to the input, GradCAM attributions are often upsampled
# and used to mask the input.
#
# Layer attribution is set up similarly to input attribution, except that
# in addition to the model, you must specify a hidden layer within the
# model that you wish to examine. As above, when we call ``attribute()``,
# we specify the target class of interest.
#

layer_gradcam = LayerGradCam(model, model.layer3[1].conv2)
attributions_lgc = layer_gradcam.attribute(input_img, target=pred_label_idx)

_ = viz.visualize_image_attr(attributions_lgc[0].cpu().permute(1,2,0).detach().numpy(),
                             sign="all",
                             title="Layer 3 Block 1 Conv 2")


##########################################################################
# We’ll use the convenience method ``interpolate()`` in the
# `LayerAttribution <https://captum.ai/api/base_classes.html?highlight=layerattribution#captum.attr.LayerAttribution>`__
# base class to upsample this attribution data for comparison to the input
# image.
#

upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, input_img.shape[2:])

print(attributions_lgc.shape)
print(upsamp_attr_lgc.shape)
print(input_img.shape)

_ = viz.visualize_image_attr_multiple(upsamp_attr_lgc[0].cpu().permute(1,2,0).detach().numpy(),
                                      transformed_img.permute(1,2,0).numpy(),
                                      ["original_image","blended_heat_map","masked_image"],
                                      ["all","positive","positive"],
                                      show_colorbar=True,
                                      titles=["Original", "Positive Attribution", "Masked"],
                                      fig_size=(18, 6))


#######################################################################
# Visualizations such as this can give you novel insights into how your
# hidden layers respond to your input.
#
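
######################################################################
# The GradCAM recipe described in this section (gradients with respect
# to a layer, averaged per channel, multiplied by the activations, then
# summed over channels) can be written out by hand on a tiny model. This
# is a sketch under stated assumptions, not Captum's implementation:
# the two-layer ``toy_conv`` / ``toy_head`` model and all other ``toy_``
# names are invented for this example, and ``F.interpolate`` with
# nearest-neighbor mode is used here as a stand-in for the upsampling
# step.
#

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # the layer we inspect
toy_head = nn.Linear(4, 3)                            # 3 made-up classes
toy_input = torch.randn(1, 1, 8, 8)

# Forward pass, keeping the layer activations so we can differentiate
toy_acts = toy_conv(toy_input)                      # shape (1, 4, 8, 8)
toy_logits = toy_head(toy_acts.mean(dim=(2, 3)))    # shape (1, 3)

# Gradients of one target class score with respect to the activations
toy_grads, = torch.autograd.grad(toy_logits[0, 1], toy_acts)

# Average the gradients over the spatial dimensions, per channel
toy_weights = toy_grads.mean(dim=(2, 3), keepdim=True)   # shape (1, 4, 1, 1)

# Weight the activations by those averages and sum over channels
toy_cam = (toy_weights * toy_acts).sum(dim=1, keepdim=True).detach()

# Upsample the coarse map to a larger (hypothetical) input resolution
toy_cam_up = F.interpolate(toy_cam, size=(16, 16), mode="nearest")
print(toy_cam.shape, toy_cam_up.shape)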

##########################################################################
# Visualization with Captum Insights
# ----------------------------------
#
# Captum Insights is an interpretability visualization widget built on top
# of Captum to facilitate model understanding. Captum Insights works
# across images, text, and other features to help users understand feature
# attribution. It allows you to visualize attribution for multiple
# input/output pairs, and provides visualization tools for image, text,
# and arbitrary data.
#
# In this section of the notebook, we’ll visualize multiple image
# classification inferences with Captum Insights.
#
# First, let’s gather some images and see what the model thinks of them.
# For variety, we’ll take our cat, a teapot, and a trilobite fossil:
#

imgs = ['img/cat.jpg', 'img/teapot.jpg', 'img/trilobite.jpg']

for img in imgs:
    img = Image.open(img)
    transformed_img = transform(img)
    input_img = transform_normalize(transformed_img)
    input_img = input_img.unsqueeze(0) # the model requires a dummy batch dimension

    output = model(input_img)
    output = F.softmax(output, dim=1)
    prediction_score, pred_label_idx = torch.topk(output, 1)
    pred_label_idx.squeeze_()
    predicted_label = idx_to_labels[str(pred_label_idx.item())][1]
    print('Predicted:', predicted_label, '/', pred_label_idx.item(), ' (', prediction_score.squeeze().item(), ')')

##########################################################################
# …and it looks like our model is identifying them all correctly - but of
# course, we want to dig deeper. For that we’ll use the Captum Insights
# widget, which we configure with an ``AttributionVisualizer`` object,
# imported below. The ``AttributionVisualizer`` expects batches of data,
# so we’ll bring in Captum’s ``Batch`` helper class. And we’ll be looking
# at images specifically, so we’ll also import ``ImageFeature``.
#
# We configure the ``AttributionVisualizer`` with the following arguments:
#
# - An array of models to be examined (in our case, just the one)
# - A scoring function, which allows Captum Insights to pull out the
#   top-k predictions from a model
# - An ordered, human-readable list of classes our model is trained on
# - A list of features to look for - in our case, an ``ImageFeature``
# - A dataset, which is an iterable object returning batches of inputs
#   and labels - just like you’d use for training
#

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

# Baseline is all-zeros input - this may differ depending on your data
def baseline_func(input):
    return input * 0

# merging our image transforms from above
def full_img_transform(input):
    i = Image.open(input)
    i = transform(i)
    i = transform_normalize(i)
    i = i.unsqueeze(0)
    return i


input_imgs = torch.cat(list(map(lambda i: full_img_transform(i), imgs)), 0)

visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda o: torch.nn.functional.softmax(o, 1),
    classes=list(map(lambda k: idx_to_labels[k][1], idx_to_labels.keys())),
    features=[
        ImageFeature(
            "Photo",
            baseline_transforms=[baseline_func],
            input_transforms=[],
        )
    ],
    dataset=[Batch(input_imgs, labels=[282,849,69])]
)

#########################################################################
# Note that running the cell above didn’t take much time at all, unlike
# our attributions above. That’s because Captum Insights lets you
# configure different attribution algorithms in a visual widget, after
# which it will compute and display the attributions. *That* process will
# take a few minutes.
#
# Running the cell below will render the Captum Insights widget. You can
# then choose attribution methods and their arguments, filter model
# responses based on predicted class or prediction correctness, see the
# model’s predictions with associated probabilities, and view heatmaps of
# the attribution compared with the original image.
#

visualizer.render()