GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/fgsm_tutorial.py
# -*- coding: utf-8 -*-
"""
Adversarial Example Generation
==============================

**Author:** `Nathan Inkawhich <https://github.com/inkawhich>`__

If you are reading this, hopefully you can appreciate how effective some
machine learning models are. Research is constantly pushing ML models to
be faster, more accurate, and more efficient. However, an often
overlooked aspect of designing and training models is security and
robustness, especially in the face of an adversary who wishes to fool
the model.

This tutorial will raise your awareness of the security vulnerabilities
of ML models, and will give insight into the hot topic of adversarial
machine learning. You may be surprised to find that adding imperceptible
perturbations to an image *can* cause drastically different model
performance. Given that this is a tutorial, we will explore the topic
via example on an image classifier. Specifically, we will use one of the
first and most popular attack methods, the Fast Gradient Sign Attack
(FGSM), to fool an MNIST classifier.

"""


######################################################################
# Threat Model
# ------------
#
# For context, there are many categories of adversarial attacks, each with
# a different goal and assumption of the attacker’s knowledge. However, in
# general the overarching goal is to add the least amount of perturbation
# to the input data to cause the desired misclassification. There are
# several kinds of assumptions of the attacker’s knowledge, two of which
# are: **white-box** and **black-box**. A *white-box* attack assumes the
# attacker has full knowledge and access to the model, including
# architecture, inputs, outputs, and weights. A *black-box* attack assumes
# the attacker only has access to the inputs and outputs of the model, and
# knows nothing about the underlying architecture or weights. There are
# also several types of goals, including **misclassification** and
# **source/target misclassification**. A goal of *misclassification* means
# the adversary only wants the output classification to be wrong but does
# not care what the new classification is. A *source/target
# misclassification* means the adversary wants to alter an image that is
# originally of a specific source class so that it is classified as a
# specific target class.
#
# In this case, the FGSM attack is a *white-box* attack with the goal of
# *misclassification*. With this background information, we can now
# discuss the attack in detail.
#
# Fast Gradient Sign Attack
# -------------------------
#
# One of the first and most popular adversarial attacks to date is
# referred to as the *Fast Gradient Sign Attack (FGSM)* and is described
# by Goodfellow et al. in `Explaining and Harnessing Adversarial
# Examples <https://arxiv.org/abs/1412.6572>`__. The attack is remarkably
# powerful, and yet intuitive. It is designed to attack neural networks by
# leveraging the way they learn: *gradients*. The idea is simple: rather
# than working to minimize the loss by adjusting the weights based on the
# backpropagated gradients, the attack *adjusts the input data to maximize
# the loss* based on the same backpropagated gradients. In other words,
# the attack uses the gradient of the loss w.r.t. the input data, then
# adjusts the input data to maximize the loss.
#
# Before we jump into the code, let’s look at the famous
# `FGSM <https://arxiv.org/abs/1412.6572>`__ panda example and extract
# some notation.
#
# .. figure:: /_static/img/fgsm_panda_image.png
#    :alt: fgsm_panda_image
#
# From the figure, :math:`\mathbf{x}` is the original input image
# correctly classified as a “panda”, :math:`y` is the ground truth label
# for :math:`\mathbf{x}`, :math:`\mathbf{\theta}` represents the model
# parameters, and :math:`J(\mathbf{\theta}, \mathbf{x}, y)` is the loss
# that is used to train the network. The attack backpropagates the
# gradient back to the input data to calculate
# :math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`. Then, it adjusts
# the input data by a small step (:math:`\epsilon` or :math:`0.007` in the
# picture) in the direction (i.e.
# :math:`sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))`) that will
# maximize the loss. The resulting perturbed image, :math:`x'`, is then
# *misclassified* by the target network as a “gibbon” when it is still
# clearly a “panda”.
#
# Hopefully now the motivation for this tutorial is clear, so let’s jump
# into the implementation.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

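
######################################################################
# Before building the full pipeline, here is a minimal, self-contained
# sketch of the core FGSM step described above, applied to a random
# tensor with a randomly initialized toy model. This snippet is an added
# illustration, not part of the original tutorial: the names (``toy_x``,
# ``toy_model``), the tensor shape, the target label, and the step size
# are arbitrary choices made only to show the
# ``x + epsilon * x.grad.sign()`` update.
#

# A single FGSM step on a toy input (illustrative only)
toy_x = torch.rand(1, 1, 28, 28, requires_grad=True)             # fake MNIST-sized input in [0, 1]
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # untrained stand-in classifier
toy_loss = F.nll_loss(F.log_softmax(toy_model(toy_x), dim=1), torch.tensor([3]))
toy_loss.backward()                                              # gradient of the loss w.r.t. the input
toy_x_adv = torch.clamp(toy_x + 0.25 * toy_x.grad.sign(), 0, 1)  # step that increases the loss, clipped to [0, 1]
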
######################################################################
# Implementation
# --------------
#
# In this section, we will discuss the input parameters for the tutorial,
# define the model under attack, then code the attack and run some tests.
#
# Inputs
# ~~~~~~
#
# There are only two inputs for this tutorial, and they are defined as
# follows:
#
# -  ``epsilons`` - List of epsilon values to use for the run. It is
#    important to keep 0 in the list because it represents the model
#    performance on the original test set. Also, intuitively we would
#    expect that the larger the epsilon, the more noticeable the
#    perturbations, but the more effective the attack in terms of
#    degrading model accuracy. Since the data range here is
#    :math:`[0,1]`, no epsilon value should exceed 1.
#
# -  ``pretrained_model`` - path to the pretrained MNIST model which was
#    trained with
#    `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
#    For simplicity, download the pretrained model `here <https://drive.google.com/file/d/1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl/view?usp=drive_link>`__.
#

epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "data/lenet_mnist_model.pth"
# Set random seed for reproducibility
torch.manual_seed(42)


######################################################################
# Model Under Attack
# ~~~~~~~~~~~~~~~~~~
#
# As mentioned, the model under attack is the same MNIST model from
# `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
# You may train and save your own MNIST model or you can download and use
# the provided model. The *Net* definition and test dataloader here have
# been copied from the MNIST example. The purpose of this section is to
# define the model and dataloader, then initialize the model and load the
# pretrained weights.
#

# LeNet Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# MNIST Test dataset and dataloader declaration
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

# We want to be able to run our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If the current accelerator is available, we will use it. Otherwise, we use the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location=device, weights_only=True))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()

######################################################################
# FGSM Attack
# ~~~~~~~~~~~
#
# Now, we can define the function that creates the adversarial examples by
# perturbing the original inputs. The ``fgsm_attack`` function takes three
# inputs: *image* is the original clean image (:math:`x`), *epsilon* is
# the pixel-wise perturbation amount (:math:`\epsilon`), and *data_grad*
# is the gradient of the loss w.r.t. the input image
# (:math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`). The function
# then creates the perturbed image as
#
# .. math:: perturbed\_image = image + epsilon*sign(data\_grad) = x + \epsilon * sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))
#
# Finally, in order to maintain the original range of the data, the
# perturbed image is clipped to the range :math:`[0,1]`.
#

# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

# restores the tensors to their original scale
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

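
######################################################################
# As an added sanity check (not part of the original tutorial), the
# snippet below verifies that ``denorm`` inverts the ``Normalize``
# transform used by the dataloader: normalizing a random batch with the
# same MNIST statistics and then denormalizing it should recover the
# original values up to floating-point error. The name ``check_batch``
# and the batch size are arbitrary.
#

check_batch = torch.rand(4, 1, 28, 28).to(device)                 # fake batch of images in [0, 1]
normed = transforms.Normalize((0.1307,), (0.3081,))(check_batch)  # same statistics as the test dataloader
print(torch.allclose(denorm(normed), check_batch, atol=1e-5))     # expected: True
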
######################################################################
# Testing Function
# ~~~~~~~~~~~~~~~~
#
# Finally, the central result of this tutorial comes from the ``test``
# function. Each call to this test function performs a full test step on
# the MNIST test set and reports a final accuracy. However, notice that
# this function also takes an *epsilon* input. This is because the
# ``test`` function reports the accuracy of a model that is under attack
# from an adversary with strength :math:`\epsilon`. More specifically, for
# each sample in the test set, the function computes the gradient of the
# loss w.r.t. the input data (:math:`data\_grad`), creates a perturbed
# image with ``fgsm_attack`` (:math:`perturbed\_data`), then checks to see
# if the perturbed example is adversarial. In addition to testing the
# accuracy of the model, the function also saves and returns some
# successful adversarial examples to be visualized later.
#

def test( model, device, test_loader, epsilon ):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in test_loader:

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set requires_grad attribute of tensor. Important for Attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect ``data_grad``
        data_grad = data.grad.data

        # Restore the data to its original scale
        data_denorm = denorm(data)

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Re-classify the perturbed image
        output = model(perturbed_data_normalized)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )

    # Calculate final accuracy for this epsilon
    final_acc = correct/float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples


######################################################################
# Run Attack
# ~~~~~~~~~~
#
# The last part of the implementation is to actually run the attack. Here,
# we run a full test step for each epsilon value in the *epsilons* input.
# For each epsilon we also save the final accuracy and some successful
# adversarial examples to be plotted in the coming sections. Notice how
# the printed accuracies decrease as the epsilon value increases. Also,
# note the :math:`\epsilon=0` case represents the original test accuracy,
# with no attack.
#

accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)


######################################################################
# Results
# -------
#
# Accuracy vs Epsilon
# ~~~~~~~~~~~~~~~~~~~
#
# The first result is the accuracy versus epsilon plot. As alluded to
# earlier, as epsilon increases we expect the test accuracy to decrease.
# This is because larger epsilons mean we take a larger step in the
# direction that will maximize the loss. Notice the trend in the curve is
# not linear even though the epsilon values are linearly spaced. For
# example, the accuracy at :math:`\epsilon=0.05` is only about 4% lower
# than :math:`\epsilon=0`, but the accuracy at :math:`\epsilon=0.2` is 25%
# lower than :math:`\epsilon=0.15`. Also, notice the accuracy of the model
# hits random accuracy for a 10-class classifier between
# :math:`\epsilon=0.25` and :math:`\epsilon=0.3`.
#

plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()

######################################################################
# Sample Adversarial Examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember the idea of no free lunch? In this case, as epsilon increases
# the test accuracy decreases **BUT** the perturbations become more easily
# perceptible. In reality, there is a tradeoff between accuracy
# degradation and perceptibility that an attacker must consider. Here, we
# show some examples of successful adversarial examples at each epsilon
# value. Each row of the plot shows a different epsilon value. The first
# row is the :math:`\epsilon=0` examples which represent the original
# “clean” images with no perturbation. The title of each image shows the
# “original classification -> adversarial classification.” Notice, the
# perturbations start to become evident at :math:`\epsilon=0.15` and are
# quite evident at :math:`\epsilon=0.3`. However, in all cases humans are
# still capable of identifying the correct class despite the added noise.
#

# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons),len(examples[0]),cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig,adv,ex = examples[i][j]
        plt.title(f"{orig} -> {adv}")
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()

######################################################################
# Where to go next?
# -----------------
#
# Hopefully this tutorial gives some insight into the topic of adversarial
# machine learning. There are many potential directions to go from here.
# This attack represents the very beginning of adversarial attack research
# and since then there have been many subsequent ideas for how to attack and
# defend ML models from an adversary. In fact, at NIPS 2017 there was an
# adversarial attack and defense competition and many of the methods used
# in the competition are described in this paper: `Adversarial Attacks and
# Defences Competition <https://arxiv.org/pdf/1804.00097.pdf>`__. The work
# on defense also leads into the idea of making machine learning models
# more *robust* in general, to both naturally perturbed and adversarially
# crafted inputs.
#
# Another direction to go is adversarial attacks and defense in different
# domains. Adversarial research is not limited to the image domain; check
# out `this <https://arxiv.org/pdf/1801.01944.pdf>`__ attack on
# speech-to-text models. But perhaps the best way to learn more about
# adversarial machine learning is to get your hands dirty. Try to
# implement a different attack from the NIPS 2017 competition, and see how
# it differs from FGSM. Then, try to defend the model from your own
# attacks.
#
# A further direction, depending on available resources, is to modify the
# code to process attacks in batches, in parallel, or in a distributed
# setting, rather than one attack at a time as in the ``test()`` loop above
# that runs once per epsilon. A rough sketch of a batched variant follows.
#

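
######################################################################
# The sketch below illustrates that last suggestion. It is an addition,
# not part of the original tutorial: ``batched_fgsm_accuracy`` and
# ``batched_loader`` are hypothetical names, the function reuses
# ``fgsm_attack``, ``denorm``, ``model``, and ``device`` from above, and
# it reports plain accuracy under attack without the per-example
# bookkeeping (or the skipping of initially misclassified samples) that
# ``test`` performs.
#

def batched_fgsm_accuracy(model, device, loader, epsilon):
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        data.requires_grad = True
        # One forward/backward pass gives the input gradient for the whole batch
        loss = F.nll_loss(model(data), target)
        model.zero_grad()
        loss.backward()
        # Perturb every image in the batch at once, then re-normalize and re-classify
        perturbed = fgsm_attack(denorm(data), epsilon, data.grad.data)
        renormed = transforms.Normalize((0.1307,), (0.3081,))(perturbed)
        pred = model(renormed).max(1)[1]
        correct += (pred == target).sum().item()
        total += target.size(0)
    return correct / total

# Example usage with a larger batch size (hypothetical loader):
# batched_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
#         transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
#     batch_size=128, shuffle=False)
# print(batched_fgsm_accuracy(model, device, batched_loader, 0.1))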