"""
Model Interpretability using Captum
===================================

"""


######################################################################
# Captum helps you understand how the data features impact your model
# predictions or neuron activations, shedding light on how your model
# operates.
#
# Using Captum, you can apply a wide range of state-of-the-art feature
# attribution algorithms such as \ ``Guided GradCam``\ and
# \ ``Integrated Gradients``\ in a unified way.
#
# In this recipe you will learn how to use Captum to:
#
# - Attribute the predictions of an image classifier to their corresponding image features.
# - Visualize the attribution results.
#


######################################################################
# Before you begin
# ----------------
#


######################################################################
# Make sure Captum is installed in your active Python environment. Captum
# is available on GitHub, as a ``pip`` package, and as a ``conda``
# package. For detailed instructions, consult the installation guide at
# https://captum.ai/
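# For example, ``pip install captum`` and ``conda install captum -c pytorch``
# are the commonly documented one-line installs; defer to that guide if the
# current instructions differ.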
#


######################################################################
# For a model, we use a built-in image classifier in PyTorch. Captum can
# reveal which parts of a sample image support certain predictions made by
# the model.
#

import torchvision
from torchvision import models, transforms
from PIL import Image
import requests
from io import BytesIO

model = torchvision.models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()

response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")
img = Image.open(BytesIO(response.content))

center_crop = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

normalize = transforms.Compose([
    transforms.ToTensor(),               # converts the image to a tensor with values between 0 and 1
    transforms.Normalize(                # normalize to follow the 0-centered ImageNet pixel RGB distribution
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
input_img = normalize(center_crop(img)).unsqueeze(0)   # add the batch dimension the model expects


######################################################################
# Computing Attribution
# ---------------------
#


######################################################################
# Among the top-3 predictions of the model are classes 208 and 283, which
# correspond to dog and cat.
#
# Let us attribute each of these predictions to the corresponding part of
# the input, using Captum’s \ ``Occlusion``\ algorithm.
#

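# A quick way to sanity-check those class indices is to look at the model's
# top-3 predictions directly (``torch.topk`` is used here purely for
# illustration):
import torch

with torch.no_grad():
    output = model(input_img)
print("Top-3 ImageNet class indices:", torch.topk(output, 3).indices.squeeze().tolist())
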
from captum.attr import Occlusion

occlusion = Occlusion(model)

strides = (3, 9, 9)                  # smaller = more fine-grained attribution but slower
target = 208                         # Labrador class index in ImageNet
sliding_window_shapes = (3, 45, 45)  # choose a size large enough to change the object's appearance
baselines = 0                        # values to occlude the image with; 0 corresponds to gray

attribution_dog = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=baselines)


target = 283                         # Persian cat class index in ImageNet
attribution_cat = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=0)


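# ``Occlusion`` returns attributions with the same shape as the input batch
# (expected to be [1, 3, 224, 224] here), one score per pixel and channel:
print(attribution_dog.shape)
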
######################################################################
# Besides ``Occlusion``, Captum features many algorithms such as
# \ ``Integrated Gradients``\ , \ ``Deconvolution``\ ,
# \ ``GuidedBackprop``\ , \ ``Guided GradCam``\ , \ ``DeepLift``\ , and
# \ ``GradientShap``\ . All of these algorithms are subclasses of
# ``Attribution``, which expects your model as a callable ``forward_func``
# upon initialization and has an ``attribute(...)`` method which returns
# the attribution result in a unified format.
#
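# For instance, ``IntegratedGradients`` follows the same pattern (a minimal
# sketch; the ``n_steps`` value below is only an illustrative choice):
#
# .. code-block:: python
#
#    from captum.attr import IntegratedGradients
#
#    integrated_gradients = IntegratedGradients(model)   # the model acts as forward_func
#    attribution_ig = integrated_gradients.attribute(input_img,
#                                                    target=208,   # Labrador class, as above
#                                                    n_steps=50)   # number of interpolation steps
#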
# Let us visualize the computed attribution results in the case of images.
#


######################################################################
# Visualizing the Results
# -----------------------
#


######################################################################
# Captum’s \ ``visualization``\ utility provides out-of-the-box methods
# to visualize attribution results both for pictorial and for textual
# inputs.
#

import numpy as np
from captum.attr import visualization as viz

# Convert the computed attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1,2,0))

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"]   # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar=True
                                     )


attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1,2,0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      np.array(center_crop(img)),
                                      ["heat_map", "original_image"],
                                      ["all", "all"],   # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar=True
                                     )


######################################################################
# If your data is textual, ``visualization.visualize_text()`` offers a
# dedicated view to explore attribution on top of the input text. Find out
# more at http://captum.ai/tutorials/IMDB_TorchText_Interpret
#


######################################################################
# Final Notes
# -----------
#


######################################################################
# Captum can handle most model types in PyTorch across modalities
# including vision, text, and more. With Captum you can:
#
# - Attribute a specific output to the model input as illustrated above.
# - Attribute a specific output to a hidden-layer neuron (see the Captum API
#   reference and the sketch below).
# - Attribute a hidden-layer neuron response to the model input (see the
#   Captum API reference).
#
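# As a rough illustration of layer-level attribution (a sketch only; the use of
# ``LayerGradCam`` and the choice of ``model.layer4`` are illustrative, not
# prescribed by this recipe):
#
# .. code-block:: python
#
#    from captum.attr import LayerGradCam
#
#    layer_gradcam = LayerGradCam(model, model.layer4)   # attribute w.r.t. a hidden layer
#    layer_attribution = layer_gradcam.attribute(input_img, target=208)
#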
# For the complete API of the supported methods and a list of tutorials,
# consult our website http://captum.ai
#
# Another useful post by Gilbert Tanner:
# https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum
#