Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/integrated_gradients.py
3507 views
1
"""
2
Title: Model interpretability with Integrated Gradients
3
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
4
Date created: 2020/06/02
5
Last modified: 2020/06/02
6
Description: How to obtain integrated gradients for a classification model.
7
Accelerator: None
8
"""
9
10
"""
11
## Integrated Gradients
12
13
[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for
14
attributing a classification model's prediction to its input features. It is
15
a model interpretability technique: you can use it to visualize the relationship
16
between input features and model predictions.
17
18
Integrated Gradients is a variation on computing
19
the gradient of the prediction output with regard to features of the input.
20
To compute integrated gradients, we need to perform the following steps:
21
22
1. Identify the input and the output. In our case, the input is an image and the
23
output is the last layer of our model (dense layer with softmax activation).
24
25
2. Compute which features are important to a neural network
26
when making a prediction on a particular data point. To identify these features, we
27
need to choose a baseline input. A baseline input can be a black image (all pixel
28
values set to zero) or random noise. The shape of the baseline input needs to be
29
the same as our input image, e.g. (299, 299, 3).
30
31
3. Interpolate between the baseline and the input image for a given number of steps. The number of steps represents
32
the steps we need in the gradient approximation for a given input image. The number of
33
steps is a hyperparameter. The authors recommend using anywhere between
34
20 and 1000 steps.
35
36
4. Preprocess these interpolated images and do a forward pass.
37
5. Get the gradients for these interpolated images.
38
6. Approximate the integral of the gradients using the trapezoidal rule.
39
40
To read in-depth about integrated gradients and why this method works,
41
consider reading this excellent
42
[article](https://distill.pub/2020/attribution-baselines/).
43
44
**References:**
45
46
- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
47
- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
48
"""
49
50
"""
51
## Setup
52
"""
53
54
55
import numpy as np
56
import matplotlib.pyplot as plt
57
from scipy import ndimage
58
from IPython.display import Image, display
59
60
import tensorflow as tf
61
import keras
62
from keras import layers
63
from keras.applications import xception
64
65
66
# Shape of a single input image: (height, width, channels)
img_size = (299, 299, 3)

# Xception classifier initialised with ImageNet weights
model = xception.Xception(weights="imagenet")

# Fetch the target image and keep the local path it was stored at
img_path = keras.utils.get_file(
    fname="elephant.jpg", origin="https://i.imgur.com/Bvro0YD.png"
)
display(Image(img_path))
75
76
"""
77
## Integrated Gradients algorithm
78
"""
79
80
81
def get_img_array(img_path, size=(299, 299)):
    """Load an image file and return it as a batch containing one image.

    Args:
        img_path: Path of the image to load.
        size: Target size passed to the loader; the image is resized to it.

    Returns:
        Numpy array with a leading batch axis of length 1, i.e. shape
        (1, 299, 299, 3) for the default size and an RGB image.
    """
    # PIL image resized to `size`
    pil_image = keras.utils.load_img(img_path, target_size=size)
    # Numpy array holding the pixel values of that image
    pixels = keras.utils.img_to_array(pil_image)
    # Prepend the batch axis so the result can be fed to the model directly
    return pixels[np.newaxis, ...]
90
91
92
def get_gradients(img_input, top_pred_idx):
    """Differentiate the score of one class with respect to the input images.

    Args:
        img_input: 4D image tensor
        top_pred_idx: Index of the class whose score is differentiated
            (the predicted label for the input image)

    Returns:
        Gradients of the selected class score w.r.t. `img_input`, as
        returned by the gradient tape.
    """
    batch = tf.cast(img_input, tf.float32)

    with tf.GradientTape() as tape:
        # The input is a plain tensor, so it is watched explicitly.
        tape.watch(batch)
        # Keep only the column of the class we want to explain.
        class_scores = model(batch)[:, top_pred_idx]

    return tape.gradient(class_scores, batch)
111
112
113
def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
    """Computes Integrated Gradients for a predicted label.

    Args:
        img_input (ndarray): Original image
        top_pred_idx: Predicted label for the input image
        baseline (ndarray): The baseline image to start with for interpolation.
            When omitted, a black image (all zeros) with the same shape as
            `img_input` is used.
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.

    Returns:
        Integrated gradients w.r.t input image
    """
    img_input = img_input.astype(np.float32)

    # If baseline is not provided, start with a black image. It is shaped
    # after the input itself so the function does not depend on a
    # module-level image size.
    if baseline is None:
        baseline = np.zeros_like(img_input)
    else:
        baseline = baseline.astype(np.float32)

    # 1. Do interpolation: num_steps + 1 images going from the baseline
    # (step 0) to the input (step num_steps).
    delta = img_input - baseline
    interpolated_image = np.array(
        [baseline + (step / num_steps) * delta for step in range(num_steps + 1)]
    ).astype(np.float32)

    # 2. Preprocess the interpolated images
    interpolated_image = xception.preprocess_input(interpolated_image)

    # 3. Get the gradients, one interpolated image at a time
    grads = [
        get_gradients(tf.expand_dims(img, axis=0), top_pred_idx=top_pred_idx)[0]
        for img in interpolated_image
    ]
    grads = tf.convert_to_tensor(grads, dtype=tf.float32)

    # 4. Approximate the integral using the trapezoidal rule: average each
    # pair of neighbouring gradients, then average over all intervals.
    grads = (grads[:-1] + grads[1:]) / 2.0
    avg_grads = tf.reduce_mean(grads, axis=0)

    # 5. Calculate integrated gradients and return
    # NOTE(review): `delta` is in raw pixel units while the gradients are taken
    # w.r.t. the preprocessed images; presumably this only rescales the
    # attributions by a constant -- confirm against `preprocess_input`.
    return delta * avg_grads
161
162
163
def random_baseline_integrated_gradients(
    img_input, top_pred_idx, num_steps=50, num_runs=2
):
    """Averages integrated gradients computed from random baseline images.

    Args:
        img_input (ndarray): 3D image
        top_pred_idx: Predicted label for the input image
        num_steps: Number of interpolation steps between the baseline
            and the input used in the computation of integrated gradients. These
            steps alone determine the integral approximation error. By default,
            num_steps is set to 50.
        num_runs: number of baseline images to generate

    Returns:
        Averaged integrated gradients for `num_runs` baseline images
    """
    # Each baseline is uniform noise scaled to [0, 255) and shaped like the
    # input, so baseline and image always line up.
    baseline_shape = np.shape(img_input)

    # 1. Get the integrated gradients for every random baseline
    integrated_grads = [
        get_integrated_gradients(
            img_input=img_input,
            top_pred_idx=top_pred_idx,
            baseline=np.random.random(baseline_shape) * 255,
            num_steps=num_steps,
        )
        for _ in range(num_runs)
    ]

    # 2. Return the average integrated gradients for the image
    integrated_grads = tf.convert_to_tensor(integrated_grads)
    return tf.reduce_mean(integrated_grads, axis=0)
197
198
199
"""
200
## Helper class for visualizing gradients and integrated gradients
201
"""
202
203
204
class GradVisualizer:
    """Plot gradients of the outputs w.r.t an input image.

    Attributions are reduced to a single-channel intensity map, rescaled,
    optionally cleaned up / outlined, coloured with `positive_channel` or
    `negative_channel`, and optionally overlaid on the image.
    """

    def __init__(self, positive_channel=None, negative_channel=None):
        """Store the colours used for each polarity.

        Args:
            positive_channel: Colour for positive attributions; defaults to
                [0, 255, 0].
            negative_channel: Colour for negative attributions; defaults to
                [255, 0, 0].
        """
        self.positive_channel = (
            [0, 255, 0] if positive_channel is None else positive_channel
        )
        self.negative_channel = (
            [255, 0, 0] if negative_channel is None else negative_channel
        )

    def apply_polarity(self, attributions, polarity):
        """Clip attributions to [0, 1] for "positive", otherwise to [-1, 0]."""
        if polarity == "positive":
            return np.clip(attributions, 0, 1)
        return np.clip(attributions, -1, 0)

    def apply_linear_transformation(
        self,
        attributions,
        clip_above_percentile=99.9,
        clip_below_percentile=70.0,
        lower_end=0.2,
    ):
        """Rescale attributions linearly between two thresholds.

        Args:
            attributions: Array of attribution values.
            clip_above_percentile: Percentile whose threshold `m` is mapped
                to 1.0.
            clip_below_percentile: Percentile whose threshold `e` is mapped
                to `lower_end`.
            lower_end: Value assigned to the lower threshold; transformed
                values below it are zeroed.

        Returns:
            Array of the same shape with values clipped to [0.0, 1.0].
        """
        # 1. Get the thresholds
        m = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_above_percentile
        )
        e = self.get_thresholded_attributions(
            attributions, percentage=100 - clip_below_percentile
        )

        # 2. Transform the attributions by a linear function f(x) = a*x + b such that
        # f(m) = 1.0 and f(e) = lower_end
        magnitude = np.abs(attributions)
        span = m - e
        if span == 0:
            # Both thresholds coincide (e.g. constant or all-zero input), so
            # the linear map is undefined. Avoid the division by zero and
            # treat everything reaching the threshold as fully attributed.
            transformed_attributions = np.where(magnitude >= m, 1.0, 0.0)
        else:
            transformed_attributions = (1 - lower_end) * (
                magnitude - e
            ) / span + lower_end

        # 3. Make sure that the sign of transformed attributions is the same as original attributions
        transformed_attributions *= np.sign(attributions)

        # 4. Only keep values that are bigger than the lower_end
        transformed_attributions *= transformed_attributions >= lower_end

        # 5. Clip values and return
        return np.clip(transformed_attributions, 0.0, 1.0)

    def get_thresholded_attributions(self, attributions, percentage):
        """Find the magnitude at which the largest values reach a share of the total.

        Magnitudes are sorted from largest to smallest and accumulated; the
        result is the first magnitude at which the running sum reaches
        `percentage` percent of the total magnitude.

        Args:
            attributions: Array of attribution values.
            percentage: Share of the total, in percent.

        Returns:
            The threshold value. For `percentage == 100` this is the minimum
            of `attributions`; for input whose magnitudes sum to zero it is 0.0.
        """
        if percentage == 100.0:
            return np.min(attributions)

        # 1. Flatten the magnitudes and sort them from largest to smallest.
        sorted_attributions = np.sort(np.abs(attributions).flatten())[::-1]

        # 2. Total magnitude. It is computed from the same absolute values
        # that are accumulated below, so the shares end at 100%.
        total = np.sum(sorted_attributions)
        if total == 0:
            # Nothing to rank: every magnitude is zero (or the input is empty).
            return 0.0

        # 3. Percentage of the total contributed by each value together
        # with all larger ones.
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total

        # 4. First position reaching the requested percentage. If rounding
        # keeps the running sum just short of it, fall back to the smallest
        # magnitude instead of failing.
        hits = np.flatnonzero(cum_sum >= percentage)
        index = hits[0] if hits.size else sorted_attributions.size - 1

        # 5. Select the desired attribution and return
        return sorted_attributions[index]

    def binarize(self, attributions, threshold=0.001):
        """Return a boolean mask of the values strictly above `threshold`."""
        return attributions > threshold

    def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
        """Apply a grey closing followed by a grey opening with `structure`."""
        closed = ndimage.grey_closing(attributions, structure=structure)
        opened = ndimage.grey_opening(closed, structure=structure)
        return opened

    def draw_outlines(
        self,
        attributions,
        percentage=90,
        connected_component_structure=np.ones((3, 3)),
    ):
        """Outline the largest connected regions of the attribution map.

        Args:
            attributions: 2D attribution map.
            percentage: Share (in percent) of the binarized area the kept
                components must cover; at most three components are kept.
            connected_component_structure: Structuring element defining
                connectivity for the component labelling.

        Returns:
            Boolean mask marking the borders of the kept components. It is
            empty when nothing survives binarization.
        """
        # 1. Binarize the attributions.
        attributions = self.binarize(attributions)

        # 2. Fill the gaps
        attributions = ndimage.binary_fill_holes(attributions)

        # 3. Compute connected components
        connected_components, num_comp = ndimage.label(
            attributions, structure=connected_component_structure
        )

        border_mask = np.zeros_like(attributions)
        if num_comp == 0:
            # No pixel is above the binarization threshold: nothing to outline.
            return border_mask

        # 4. Sum up the attributions for each component. The map is binary
        # at this point, so each sum is the component's pixel count.
        total = np.sum(attributions[connected_components > 0])
        component_sums = []
        for comp in range(1, num_comp + 1):
            mask = connected_components == comp
            component_sums.append((np.sum(attributions[mask]), mask))

        # 5. Compute the percentage of top components to keep
        sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
        cumulative_sorted_sums = np.cumsum([size for size, _ in sorted_sums_and_masks])
        cutoff_threshold = percentage * total / 100
        cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
        # Never keep more than the three largest components.
        cutoff_idx = min(cutoff_idx, 2)

        # 6. Set the values for the kept components
        for i in range(cutoff_idx + 1):
            border_mask[sorted_sums_and_masks[i][1]] = 1

        # 7. Make the mask hollow and show only the border
        eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
        border_mask[eroded_mask] = 0

        # 8. Return the outlined mask
        return border_mask

    def process_grads(
        self,
        image,
        attributions,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
    ):
        """Turn raw attributions into a coloured map, optionally over the image.

        Args:
            image: Image the attributions are overlaid on when `overlay` is set.
            attributions: Attributions with the channel axis last (axis 2).
            polarity: "positive" or "negative"; which sign to visualize.
            clip_above_percentile: Upper percentile for the linear rescaling.
            clip_below_percentile: Lower percentile for the linear rescaling.
            morphological_cleanup: Whether to apply the closing/opening cleanup.
            structure: Structuring element for the cleanup.
            outlines: Whether to replace the map by component outlines.
            outlines_component_percentage: `percentage` passed to
                `draw_outlines`.
            overlay: Whether to blend the coloured map onto `image`.

        Returns:
            Array with a trailing colour axis; clipped to [0, 255] when
            overlaid.

        Raises:
            ValueError: If `polarity` is not one of the allowed values or a
                percentile lies outside [0, 100].
        """
        if polarity not in ["positive", "negative"]:
            raise ValueError(
                f" Allowed polarity values: 'positive' or 'negative'\nbut provided {polarity}"
            )
        if clip_above_percentile < 0 or clip_above_percentile > 100:
            raise ValueError("clip_above_percentile must be in [0, 100]")

        if clip_below_percentile < 0 or clip_below_percentile > 100:
            raise ValueError("clip_below_percentile must be in [0, 100]")

        # 1. Apply polarity; negative attributions are flipped to magnitudes
        attributions = self.apply_polarity(attributions, polarity=polarity)
        if polarity == "positive":
            channel = self.positive_channel
        else:
            attributions = np.abs(attributions)
            channel = self.negative_channel

        # 2. Take average over the channels
        attributions = np.average(attributions, axis=2)

        # 3. Apply linear transformation to the attributions
        attributions = self.apply_linear_transformation(
            attributions,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            lower_end=0.0,
        )

        # 4. Cleanup
        if morphological_cleanup:
            attributions = self.morphological_cleanup_fn(
                attributions, structure=structure
            )

        # 5. Draw the outlines
        if outlines:
            attributions = self.draw_outlines(
                attributions, percentage=outlines_component_percentage
            )

        # 6. Expand the channel axis and convert to RGB
        attributions = np.expand_dims(attributions, 2) * channel

        # 7. Superimpose on the original image
        if overlay:
            attributions = np.clip((attributions * 0.8 + image), 0, 255)
        return attributions

    def visualize(
        self,
        image,
        gradients,
        integrated_gradients,
        polarity="positive",
        clip_above_percentile=99.9,
        clip_below_percentile=0,
        morphological_cleanup=False,
        structure=np.ones((3, 3)),
        outlines=False,
        outlines_component_percentage=90,
        overlay=True,
        figsize=(15, 8),
    ):
        """Show the input, its plain gradients and its integrated gradients.

        Both attribution maps go through `process_grads` with the same
        settings and are drawn next to the input in a 1x3 figure.

        Args:
            image: Input image to display and overlay on.
            gradients: Plain gradient attributions for the image.
            integrated_gradients: Integrated-gradient attributions.
            figsize: Size of the matplotlib figure.
            (all other arguments are forwarded to `process_grads`)
        """
        # Settings shared by both attribution maps
        shared = dict(
            polarity=polarity,
            clip_above_percentile=clip_above_percentile,
            clip_below_percentile=clip_below_percentile,
            morphological_cleanup=morphological_cleanup,
            structure=structure,
            outlines=outlines,
            outlines_component_percentage=outlines_component_percentage,
            overlay=overlay,
        )

        # Each map gets its own copy of the original image
        grads_attr = self.process_grads(
            image=np.copy(image), attributions=gradients, **shared
        )
        igrads_attr = self.process_grads(
            image=np.copy(image), attributions=integrated_gradients, **shared
        )

        _, ax = plt.subplots(1, 3, figsize=figsize)
        ax[0].imshow(image)
        ax[1].imshow(grads_attr.astype(np.uint8))
        ax[2].imshow(igrads_attr.astype(np.uint8))

        ax[0].set_title("Input")
        ax[1].set_title("Normal gradients")
        ax[2].set_title("Integrated gradients")
        plt.show()
452
453
454
"""
455
## Let's test-drive it
456
"""
457
458
# 1. Load the image as a numpy batch containing a single image
img = get_img_array(img_path)

# 2. Keep a uint8 copy of the raw pixels, taken before preprocessing,
#    for display and as the input to integrated gradients
orig_img = np.copy(img[0]).astype(np.uint8)

# 3. Preprocess the batch for Xception
img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)

# 4. Predict and pick the highest-scoring class
preds = model.predict(img_processed)
top_pred_idx = tf.argmax(preds[0])
print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])

# 5. Plain gradients of the predicted class w.r.t. the preprocessed image
grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)

# 6. Integrated gradients averaged over random baselines
igrads = random_baseline_integrated_gradients(
    np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
)

# 7. Plot both attribution maps: first the raw overlay, then a cleaned-up
#    version reduced to outlines
vis = GradVisualizer()
plain_gradients = grads[0].numpy()
integrated_gradients = igrads.numpy()

vis.visualize(
    image=orig_img,
    gradients=plain_gradients,
    integrated_gradients=integrated_gradients,
    clip_above_percentile=99,
    clip_below_percentile=0,
)

vis.visualize(
    image=orig_img,
    gradients=plain_gradients,
    integrated_gradients=integrated_gradients,
    clip_above_percentile=95,
    clip_below_percentile=28,
    morphological_cleanup=True,
    outlines=True,
)
499
500