"""1Model Interpretability using Captum2===================================34"""567######################################################################8# Captum helps you understand how the data features impact your model9# predictions or neuron activations, shedding light on how your model10# operates.11#12# Using Captum, you can apply a wide range of state-of-the-art feature13# attribution algorithms such as \ ``Guided GradCam``\ and14# \ ``Integrated Gradients``\ in a unified way.15#16# In this recipe you will learn how to use Captum to:17#18# - Attribute the predictions of an image classifier to their corresponding image features.19# - Visualize the attribution results.20#212223######################################################################24# Before you begin25# ----------------26#272829######################################################################30# Make sure Captum is installed in your active Python environment. Captum31# is available both on GitHub, as a ``pip`` package, or as a ``conda``32# package. For detailed instructions, consult the installation guide at33# https://captum.ai/34#353637######################################################################38# For a model, we use a built-in image classifier in PyTorch. Captum can39# reveal which parts of a sample image support certain predictions made by40# the model.41#4243import torchvision44from torchvision import models, transforms45from PIL import Image46import requests47from io import BytesIO4849model = torchvision.models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()5051response = requests.get("https://image.freepik.com/free-photo/two-beautiful-puppies-cat-dog_58409-6024.jpg")52img = Image.open(BytesIO(response.content))5354center_crop = transforms.Compose([55transforms.Resize(256),56transforms.CenterCrop(224),57])5859normalize = transforms.Compose([60transforms.ToTensor(), # converts the image to a tensor with values between 0 and 161transforms.Normalize( # normalize to follow 0-centered imagenet pixel RGB distribution62mean=[0.485, 0.456, 0.406],63std=[0.229, 0.224, 0.225]64)65])66input_img = normalize(center_crop(img)).unsqueeze(0)676869######################################################################70# Computing Attribution71# ---------------------72#737475######################################################################76# Among the top-3 predictions of the models are classes 208 and 283 which77# correspond to dog and cat.78#79# Let us attribute each of these predictions to the corresponding part of80# the input, using Captum’s \ ``Occlusion``\ algorithm.81#8283from captum.attr import Occlusion8485occlusion = Occlusion(model)8687strides = (3, 9, 9) # smaller = more fine-grained attribution but slower88target=208, # Labrador index in ImageNet89sliding_window_shapes=(3,45, 45) # choose size enough to change object appearance90baselines = 0 # values to occlude the image with. 
######################################################################
# Let us attribute each of these predictions to the corresponding part of
# the input, using Captum’s \ ``Occlusion``\ algorithm.
#

from captum.attr import Occlusion

occlusion = Occlusion(model)

strides = (3, 9, 9)                  # smaller = more fine-grained attribution but slower
target = 208                         # Labrador index in ImageNet
sliding_window_shapes = (3, 45, 45)  # choose a size large enough to change the object's appearance
baselines = 0                        # values to occlude the image with; 0 corresponds to gray

attribution_dog = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=baselines)


target = 283                         # Persian cat index in ImageNet
attribution_cat = occlusion.attribute(input_img,
                                      strides=strides,
                                      target=target,
                                      sliding_window_shapes=sliding_window_shapes,
                                      baselines=baselines)


######################################################################
# Besides ``Occlusion``, Captum features many algorithms such as
# \ ``Integrated Gradients``\ , \ ``Deconvolution``\ ,
# \ ``GuidedBackprop``\ , \ ``Guided GradCam``\ , \ ``DeepLift``\ , and
# \ ``GradientShap``\ . All of these algorithms are subclasses of
# ``Attribution``, which expects your model as a callable ``forward_func``
# upon initialization and has an ``attribute(...)`` method that returns
# the attribution result in a unified format, as the sketch below
# illustrates.
#
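######################################################################
# A brief sketch of that unified interface, not part of the original
# recipe: swapping ``Occlusion`` for ``IntegratedGradients`` only changes
# the constructor, while the ``attribute(...)`` call keeps the same form.
#

from captum.attr import IntegratedGradients

integrated_gradients = IntegratedGradients(model)  # the model is passed as the callable forward_func
ig_attribution_dog = integrated_gradients.attribute(input_img, target=208)  # same unified attribute(...) call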
######################################################################
# Let us visualize the computed attribution results in the case of images.
#


######################################################################
# Visualizing the Results
# -----------------------
#


######################################################################
# Captum’s \ ``visualization``\ utility provides out-of-the-box methods
# to visualize attribution results both for pictorial and for textual
# inputs.
#

import numpy as np
from captum.attr import visualization as viz

# Convert the computed attribution tensor into an image-like numpy array
attribution_dog = np.transpose(attribution_dog.squeeze().cpu().detach().numpy(), (1, 2, 0))

vis_types = ["heat_map", "original_image"]
vis_signs = ["all", "all"]  # "positive", "negative", or "all" to show both
# positive attribution indicates that the presence of the area increases the prediction score
# negative attribution indicates distractor areas whose absence increases the score

_ = viz.visualize_image_attr_multiple(attribution_dog,
                                      np.array(center_crop(img)),
                                      vis_types,
                                      vis_signs,
                                      ["attribution for dog", "image"],
                                      show_colorbar=True)


attribution_cat = np.transpose(attribution_cat.squeeze().cpu().detach().numpy(), (1, 2, 0))

_ = viz.visualize_image_attr_multiple(attribution_cat,
                                      np.array(center_crop(img)),
                                      ["heat_map", "original_image"],
                                      ["all", "all"],  # positive/negative attribution or all
                                      ["attribution for cat", "image"],
                                      show_colorbar=True)


######################################################################
# If your data is textual, ``visualization.visualize_text()`` offers a
# dedicated view to explore attribution on top of the input text. Find out
# more at http://captum.ai/tutorials/IMDB_TorchText_Interpret
#


######################################################################
# Final Notes
# -----------
#


######################################################################
# Captum can handle most model types in PyTorch across modalities
# including vision, text, and more. With Captum you can:
#
# - Attribute a specific output to the model input, as illustrated above.
# - Attribute a specific output to a hidden-layer neuron (see the Captum API reference).
# - Attribute a hidden-layer neuron response to the model input (see the Captum API reference).
#
# A sketch of the layer-attribution case follows this section.
#
# For the complete API of the supported methods and a list of tutorials,
# consult the Captum website: http://captum.ai
#
# Another useful post by Gilbert Tanner:
# https://gilberttanner.com/blog/interpreting-pytorch-models-with-captum
#
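######################################################################
# A hedged sketch of the layer-attribution case, not part of the original
# recipe: ``LayerGradCam`` attributes a target output to the activations
# of a chosen hidden layer. Picking ``model.layer4``, the last
# convolutional block of this ResNet-18, is our assumption here, not a
# prescription from the recipe.
#

from captum.attr import LayerGradCam, LayerAttribution

layer_gradcam = LayerGradCam(model, model.layer4)                # attribute w.r.t. a hidden layer
layer_attr_dog = layer_gradcam.attribute(input_img, target=208)  # coarse attribution map over layer4
# upsample the layer attribution back to the input resolution for viewing
upsampled_attr = LayerAttribution.interpolate(layer_attr_dog, input_img.shape[2:])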