GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/fgsm_tutorial.py
# -*- coding: utf-8 -*-
"""
Adversarial Example Generation
==============================

**Author:** `Nathan Inkawhich <https://github.com/inkawhich>`__

If you are reading this, hopefully you can appreciate how effective some
machine learning models are. Research is constantly pushing ML models to
be faster, more accurate, and more efficient. However, an often
overlooked aspect of designing and training models is security and
robustness, especially in the face of an adversary who wishes to fool
the model.

This tutorial will raise your awareness of the security vulnerabilities
of ML models, and will give insight into the hot topic of adversarial
machine learning. You may be surprised to find that adding imperceptible
perturbations to an image *can* cause drastically different model
performance. Given that this is a tutorial, we will explore the topic
via an example on an image classifier. Specifically, we will use one of
the first and most popular attack methods, the Fast Gradient Sign Attack
(FGSM), to fool an MNIST classifier.

"""


######################################################################
# Threat Model
# ------------
#
# For context, there are many categories of adversarial attacks, each with
# a different goal and assumption of the attacker’s knowledge. However, in
# general the overarching goal is to add the least amount of perturbation
# to the input data to cause the desired misclassification. There are
# several kinds of assumptions of the attacker’s knowledge, two of which
# are: **white-box** and **black-box**. A *white-box* attack assumes the
# attacker has full knowledge of and access to the model, including
# architecture, inputs, outputs, and weights. A *black-box* attack assumes
# the attacker only has access to the inputs and outputs of the model, and
# knows nothing about the underlying architecture or weights. There are
# also several types of goals, including **misclassification** and
# **source/target misclassification**. A goal of *misclassification* means
# the adversary only wants the output classification to be wrong but does
# not care what the new classification is. A *source/target
# misclassification* means the adversary wants to alter an image that is
# originally of a specific source class so that it is classified as a
# specific target class.
#
# In this case, the FGSM attack is a *white-box* attack with the goal of
# *misclassification*. With this background information, we can now
# discuss the attack in detail.
#
# Fast Gradient Sign Attack
# -------------------------
#
# One of the first and most popular adversarial attacks to date is
# referred to as the *Fast Gradient Sign Attack (FGSM)* and is described
# by Goodfellow et al. in `Explaining and Harnessing Adversarial
# Examples <https://arxiv.org/abs/1412.6572>`__. The attack is remarkably
# powerful, and yet intuitive. It is designed to attack neural networks by
# leveraging the way they learn, *gradients*. The idea is simple: rather
# than working to minimize the loss by adjusting the weights based on the
# backpropagated gradients, the attack *adjusts the input data to maximize
# the loss* based on the same backpropagated gradients. In other words,
# the attack uses the gradient of the loss w.r.t. the input data, then
# adjusts the input data to maximize the loss.
#
# Before we jump into the code, let’s look at the famous
# `FGSM <https://arxiv.org/abs/1412.6572>`__ panda example and extract
# some notation.
#
# .. figure:: /_static/img/fgsm_panda_image.png
#    :alt: fgsm_panda_image
#
# From the figure, :math:`\mathbf{x}` is the original input image
# correctly classified as a “panda”, :math:`y` is the ground truth label
# for :math:`\mathbf{x}`, :math:`\mathbf{\theta}` represents the model
# parameters, and :math:`J(\mathbf{\theta}, \mathbf{x}, y)` is the loss
# that is used to train the network. The attack backpropagates the
# gradient back to the input data to calculate
# :math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`. Then, it adjusts
# the input data by a small step (:math:`\epsilon` or :math:`0.007` in the
# picture) in the direction (i.e.,
# :math:`sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))`) that will
# maximize the loss. The resulting perturbed image, :math:`x'`, is then
# *misclassified* by the target network as a “gibbon” when it is still
# clearly a “panda”.
#
# Hopefully the motivation for this tutorial is now clear, so let’s jump
# into the implementation.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

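
######################################################################
# Before building the full pipeline, here is a tiny, self-contained sketch of
# the single FGSM step :math:`x' = x + \epsilon \cdot sign(\nabla_{x} J)`.
# This sketch is an illustrative addition to the tutorial, not part of the
# attack itself: the names ``toy_x``, ``toy_target``, and ``toy_eps`` are
# made up for demonstration, and a simple quadratic loss stands in for the
# network loss :math:`J(\mathbf{\theta}, \mathbf{x}, y)`.

# Illustrative toy example (hypothetical names, quadratic stand-in loss)
toy_x = torch.tensor([0.2, 0.8], requires_grad=True)   # a made-up 2-pixel "image"
toy_target = torch.tensor([0.0, 1.0])
toy_loss = ((toy_x - toy_target) ** 2).sum()           # stand-in for J(theta, x, y)
toy_loss.backward()                                    # gradient of the loss w.r.t. the input
toy_eps = 0.1
toy_x_adv = toy_x + toy_eps * toy_x.grad.sign()        # one FGSM step: move *up* the loss
print("toy input:      ", toy_x.data)
print("toy gradient:   ", toy_x.grad)
print("toy adversarial:", toy_x_adv.data)
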
######################################################################
# Implementation
# --------------
#
# In this section, we will discuss the input parameters for the tutorial,
# define the model under attack, then code the attack and run some tests.
#
# Inputs
# ~~~~~~
#
# There are only three inputs for this tutorial, and they are defined as
# follows:
#
# -  ``epsilons`` - List of epsilon values to use for the run. It is
#    important to keep 0 in the list because it represents the model
#    performance on the original test set. Also, intuitively we would
#    expect the larger the epsilon, the more noticeable the perturbations
#    but the more effective the attack in terms of degrading model
#    accuracy. Since the data range here is :math:`[0,1]`, no epsilon
#    value should exceed 1.
#
# -  ``pretrained_model`` - path to the pretrained MNIST model which was
#    trained with
#    `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
#    For simplicity, download the pretrained model `here <https://drive.google.com/file/d/1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl/view?usp=drive_link>`__.
#
# -  ``use_cuda`` - boolean flag to use CUDA if desired and available.
#    Note, a GPU with CUDA is not critical for this tutorial as a CPU will
#    not take much time.
#

epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "data/lenet_mnist_model.pth"
use_cuda = True
# Set random seed for reproducibility
torch.manual_seed(42)

######################################################################
# Model Under Attack
# ~~~~~~~~~~~~~~~~~~
#
# As mentioned, the model under attack is the same MNIST model from
# `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
# You may train and save your own MNIST model or you can download and use
# the provided model. The *Net* definition and test dataloader here have
# been copied from the MNIST example. The purpose of this section is to
# define the model and dataloader, then initialize the model and load the
# pretrained weights.
#

# LeNet Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# MNIST Test dataset and dataloader declaration
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

# Define what device we are using
print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location=device, weights_only=True))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()

######################################################################
# FGSM Attack
# ~~~~~~~~~~~
#
# Now, we can define the function that creates the adversarial examples by
# perturbing the original inputs. The ``fgsm_attack`` function takes three
# inputs: *image* is the original clean image (:math:`x`), *epsilon* is
# the pixel-wise perturbation amount (:math:`\epsilon`), and *data_grad*
# is the gradient of the loss w.r.t. the input image
# (:math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`). The function
# then creates the perturbed image as
#
# .. math:: perturbed\_image = image + epsilon*sign(data\_grad) = x + \epsilon * sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))
#
# Finally, in order to maintain the original range of the data, the
# perturbed image is clipped to the range :math:`[0,1]`.
#

# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Add clipping to maintain the [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

# Restores the tensors to their original scale
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of normalized tensors back to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: Batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)


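######################################################################
# As a quick sanity check (an illustrative addition, not part of the original
# tutorial), un-normalizing a batch with ``denorm`` and then re-applying the
# same ``Normalize`` transform should recover the normalized tensor up to
# floating-point error. The names ``_check_batch`` and ``_roundtrip`` are
# hypothetical helpers used only for this check.

# Hypothetical round-trip check: Normalize(denorm(x)) should equal x
_check_batch, _ = next(iter(test_loader))
_check_batch = _check_batch.to(device)
_roundtrip = transforms.Normalize((0.1307,), (0.3081,))(denorm(_check_batch))
print("denorm round-trip max error:", (_roundtrip - _check_batch).abs().max().item())

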
######################################################################
# Testing Function
# ~~~~~~~~~~~~~~~~
#
# Finally, the central result of this tutorial comes from the ``test``
# function. Each call to this test function performs a full test step on
# the MNIST test set and reports a final accuracy. However, notice that
# this function also takes an *epsilon* input. This is because the
# ``test`` function reports the accuracy of a model that is under attack
# from an adversary with strength :math:`\epsilon`. More specifically, for
# each sample in the test set, the function computes the gradient of the
# loss w.r.t. the input data (:math:`data\_grad`), creates a perturbed
# image with ``fgsm_attack`` (:math:`perturbed\_data`), then checks to see
# if the perturbed example is adversarial. In addition to testing the
# accuracy of the model, the function also saves and returns some
# successful adversarial examples to be visualized later.
#

def test(model, device, test_loader, epsilon):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in test_loader:

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect ``data_grad``
        data_grad = data.grad.data

        # Restore the data to its original scale
        data_denorm = denorm(data)

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Re-classify the perturbed image
        output = model(perturbed_data_normalized)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    # Calculate final accuracy for this epsilon
    final_acc = correct / float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples


######################################################################
# Run Attack
# ~~~~~~~~~~
#
# The last part of the implementation is to actually run the attack. Here,
# we run a full test step for each epsilon value in the *epsilons* input.
# For each epsilon we also save the final accuracy and some successful
# adversarial examples to be plotted in the coming sections. Notice how
# the printed accuracies decrease as the epsilon value increases. Also,
# note that the :math:`\epsilon=0` case represents the original test
# accuracy, with no attack.
#

accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)


######################################################################
# Results
# -------
#
# Accuracy vs Epsilon
# ~~~~~~~~~~~~~~~~~~~
#
# The first result is the accuracy versus epsilon plot. As alluded to
# earlier, as epsilon increases we expect the test accuracy to decrease.
# This is because larger epsilons mean we take a larger step in the
# direction that will maximize the loss. Notice that the trend in the curve
# is not linear even though the epsilon values are linearly spaced. For
# example, the accuracy at :math:`\epsilon=0.05` is only about 4% lower
# than at :math:`\epsilon=0`, but the accuracy at :math:`\epsilon=0.2` is 25%
# lower than at :math:`\epsilon=0.15`. Also, notice that the accuracy of the
# model hits random accuracy for a 10-class classifier between
# :math:`\epsilon=0.25` and :math:`\epsilon=0.3`.
#

plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()


######################################################################
# Sample Adversarial Examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember the idea of no free lunch? In this case, as epsilon increases
# the test accuracy decreases **BUT** the perturbations become more easily
# perceptible. In reality, there is a tradeoff between accuracy
# degradation and perceptibility that an attacker must consider. Here, we
# show some examples of successful adversarial examples at each epsilon
# value. Each row of the plot shows a different epsilon value. The first
# row is the :math:`\epsilon=0` examples which represent the original
# “clean” images with no perturbation. The title of each image shows the
# “original classification -> adversarial classification.” Notice, the
# perturbations start to become evident at :math:`\epsilon=0.15` and are
# quite evident at :math:`\epsilon=0.3`. However, in all cases humans are
# still capable of identifying the correct class despite the added noise.
#

# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[0]), cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig, adv, ex = examples[i][j]
        plt.title(f"{orig} -> {adv}")
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()


######################################################################
# Where to go next?
# -----------------
#
# Hopefully this tutorial gives some insight into the topic of adversarial
# machine learning. There are many potential directions to go from here.
# This attack represents the very beginning of adversarial attack research,
# and since then there have been many subsequent ideas for how to attack and
# defend ML models from an adversary. In fact, at NIPS 2017 there was an
# adversarial attack and defense competition, and many of the methods used
# in the competition are described in this paper: `Adversarial Attacks and
# Defences Competition <https://arxiv.org/pdf/1804.00097.pdf>`__. The work
# on defense also leads into the idea of making machine learning models
# more *robust* in general, to both naturally perturbed and adversarially
# crafted inputs.
#
# Another direction to go is adversarial attacks and defenses in different
# domains. Adversarial research is not limited to the image domain; check
# out `this <https://arxiv.org/pdf/1801.01944.pdf>`__ attack on
# speech-to-text models. But perhaps the best way to learn more about
# adversarial machine learning is to get your hands dirty. Try to
# implement a different attack from the NIPS 2017 competition, and see how
# it differs from FGSM. Then, try to defend the model from your own
# attacks.
#
# A further direction to go, depending on available resources, is to modify
# the code to process examples in batches, in parallel, or in a distributed
# fashion, rather than attacking one example at a time in the ``test()`` loop
# over each epsilon above. One possible batched variant is sketched below.
#

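######################################################################
# The following is an illustrative sketch of that batched idea, not part of
# the original tutorial: it attacks a whole batch at once instead of one
# image at a time. The helper name ``batched_test`` and the batch size of
# 128 are assumptions for demonstration, and unlike ``test`` above this
# sketch does not skip initially misclassified samples, so its accuracy
# numbers will differ slightly.

def batched_test(model, device, loader, epsilon):
    # Hypothetical batched variant of the per-sample test() above
    correct = 0
    total = 0
    renorm = transforms.Normalize((0.1307,), (0.3081,))
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True
        output = model(data)
        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        # One FGSM step for the whole batch at once
        perturbed = fgsm_attack(denorm(data), epsilon, data.grad.data)
        final_pred = model(renorm(perturbed)).max(1)[1]
        correct += (final_pred == target).sum().item()
        total += target.size(0)
    acc = correct / float(total)
    print(f"[batched] Epsilon: {epsilon}\tTest Accuracy = {correct} / {total} = {acc}")
    return acc

# Example usage with an assumed larger batch size (left commented out so the
# tutorial's runtime is unchanged):
# batched_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
#         transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
#     batch_size=128, shuffle=True)
# batched_test(model, device, batched_loader, epsilon=0.1)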