# -*- coding: utf-8 -*-
"""
Adversarial Example Generation
==============================

**Author:** `Nathan Inkawhich <https://github.com/inkawhich>`__

If you are reading this, hopefully you can appreciate how effective some
machine learning models are. Research is constantly pushing ML models to
be faster, more accurate, and more efficient. However, an often
overlooked aspect of designing and training models is security and
robustness, especially in the face of an adversary who wishes to fool
the model.

This tutorial will raise your awareness of the security vulnerabilities
of ML models and will give insight into the hot topic of adversarial
machine learning. You may be surprised to find that adding imperceptible
perturbations to an image *can* cause drastically different model
performance. Given that this is a tutorial, we will explore the topic
through an example on an image classifier. Specifically, we will use one
of the first and most popular attack methods, the Fast Gradient Sign
Attack (FGSM), to fool an MNIST classifier.

"""


######################################################################
# Threat Model
# ------------
#
# For context, there are many categories of adversarial attacks, each with
# a different goal and assumption of the attacker’s knowledge. However, in
# general the overarching goal is to add the least amount of perturbation
# to the input data to cause the desired misclassification. There are
# several kinds of assumptions of the attacker’s knowledge, two of which
# are: **white-box** and **black-box**. A *white-box* attack assumes the
# attacker has full knowledge and access to the model, including
# architecture, inputs, outputs, and weights. A *black-box* attack assumes
# the attacker only has access to the inputs and outputs of the model, and
# knows nothing about the underlying architecture or weights.
#
# There are also several types of goals, including **misclassification** and
# **source/target misclassification**. A goal of *misclassification* means
# the adversary only wants the output classification to be wrong but does
# not care what the new classification is. A *source/target
# misclassification* means the adversary wants to alter an image that is
# originally of a specific source class so that it is classified as a
# specific target class.
#
# In this case, FGSM is a *white-box* attack with the goal of
# *misclassification*. With this background information, we can now
# discuss the attack in detail.
#
# Fast Gradient Sign Attack
# -------------------------
#
# One of the first and most popular adversarial attacks to date is
# referred to as the *Fast Gradient Sign Attack (FGSM)* (the *fast
# gradient sign method* in the original paper) and is described
# by Goodfellow et al. in `Explaining and Harnessing Adversarial
# Examples <https://arxiv.org/abs/1412.6572>`__. The attack is remarkably
# powerful, and yet intuitive. It is designed to attack neural networks by
# leveraging the way they learn: *gradients*. The idea is simple: rather
# than working to minimize the loss by adjusting the weights based on the
# backpropagated gradients, the attack *adjusts the input data to maximize
# the loss* based on the same backpropagated gradients. In other words,
# the attack uses the gradient of the loss w.r.t. the input data, then
# adjusts the input data to maximize the loss.
#
# Before we jump into the code, let’s look at the famous
# `FGSM <https://arxiv.org/abs/1412.6572>`__ panda example and extract
# some notation.
#
# .. figure:: /_static/img/fgsm_panda_image.png
#    :alt: fgsm_panda_image
#
# From the figure, :math:`\mathbf{x}` is the original input image
# correctly classified as a “panda”, :math:`y` is the ground truth label
# for :math:`\mathbf{x}`, :math:`\mathbf{\theta}` represents the model
# parameters, and :math:`J(\mathbf{\theta}, \mathbf{x}, y)` is the loss
# that is used to train the network. The attack backpropagates the
# gradient back to the input data to calculate
# :math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`. Then, it adjusts
# the input data by a small step (:math:`\epsilon`, or :math:`0.007` in the
# picture) in the direction (i.e.,
# :math:`sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))`) that will
# maximize the loss. The resulting perturbed image, :math:`x'`, is then
# *misclassified* by the target network as a “gibbon” when it is still
# clearly a “panda”.
#
# Hopefully now the motivation for this tutorial is clear, so let’s jump
# into the implementation.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt


######################################################################
# Implementation
# --------------
#
# In this section, we will discuss the input parameters for the tutorial,
# define the model under attack, then code the attack and run some tests.
#
# Inputs
# ~~~~~~
#
# There are only two inputs for this tutorial, defined as follows:
#
# -  ``epsilons`` - List of epsilon values to use for the run. It is
#    important to keep 0 in the list because it represents the model
#    performance on the original test set. Also, intuitively we would
#    expect the larger the epsilon, the more noticeable the perturbations
#    but the more effective the attack in terms of degrading model
#    accuracy.
#    Since the data range here is :math:`[0,1]`, no epsilon
#    value should exceed 1.
#
# -  ``pretrained_model`` - path to the pretrained MNIST model which was
#    trained with
#    `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
#    For simplicity, download the pretrained model `here <https://drive.google.com/file/d/1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl/view?usp=drive_link>`__.
#

epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "data/lenet_mnist_model.pth"
# Set random seed for reproducibility
torch.manual_seed(42)


######################################################################
# Model Under Attack
# ~~~~~~~~~~~~~~~~~~
#
# As mentioned, the model under attack is the same MNIST model from
# `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
# You may train and save your own MNIST model or you can download and use
# the provided model. The *Net* definition and test dataloader here have
# been copied from the MNIST example.
# The purpose of this section is to
# define the model and dataloader, then initialize the model and load the
# pretrained weights.
#

# LeNet Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 (28 -> 26 -> 24 after the two 3x3
        # convolutions, then 24 -> 12 after 2x2 max pooling)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# MNIST Test dataset and dataloader declaration
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

# We want to be able to run our model on an `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
# such as CUDA, MPS, MTIA, or XPU. If an accelerator is available, we will use it. Otherwise, we use the CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location=device, weights_only=True))

# Set the model in evaluation mode. In this case, this disables the Dropout layers
model.eval()


######################################################################
# FGSM Attack
# ~~~~~~~~~~~
#
# Now, we can define the function that creates the adversarial examples by
# perturbing the original inputs.
# The ``fgsm_attack`` function takes three
# inputs: *image* is the original clean image (:math:`x`), *epsilon* is
# the pixel-wise perturbation amount (:math:`\epsilon`), and *data_grad*
# is the gradient of the loss w.r.t. the input image
# (:math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`). The function
# then creates the perturbed image as
#
# .. math:: perturbed\_image = image + epsilon*sign(data\_grad) = x + \epsilon * sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))
#
# Finally, in order to maintain the original range of the data, the
# perturbed image is clipped to the range :math:`[0,1]`.
#

# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

# Restores the tensors to their original scale
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(batch.device)
    if isinstance(std, list):
        std = torch.tensor(std).to(batch.device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)


######################################################################
# Testing Function
# ~~~~~~~~~~~~~~~~
#
# Finally, the central result of this tutorial comes from the ``test``
# function.
# Each call to this test function performs a full test step on
# the MNIST test set and reports a final accuracy. However, notice that
# this function also takes an *epsilon* input. This is because the
# ``test`` function reports the accuracy of a model that is under attack
# from an adversary with strength :math:`\epsilon`. More specifically, for
# each sample in the test set, the function computes the gradient of the
# loss w.r.t. the input data (:math:`data\_grad`), creates a perturbed
# image with ``fgsm_attack`` (:math:`perturbed\_data`), then checks
# whether the perturbed example is adversarial. In addition to testing the
# accuracy of the model, the function also saves and returns some
# successful adversarial examples to be visualized later.
#

def test(model, device, test_loader, epsilon):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in test_loader:

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set the requires_grad attribute of the tensor.
        # This is important for the attack, which needs the gradient w.r.t. the input
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect ``data_grad``
        data_grad = data.grad.data

        # Restore the data to its original scale
        data_denorm = denorm(data)

        # Call FGSM Attack. data_grad is taken w.r.t. the normalized input, but
        # denormalization only scales by a positive std and shifts by the mean,
        # so the gradient has the same sign in [0,1] pixel space
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Re-classify the perturbed image
        output = model(perturbed_data_normalized)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append( (init_pred.item(), final_pred.item(), adv_ex) )

    # Calculate final accuracy for this epsilon
    final_acc = correct/float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and the saved adversarial examples
    return final_acc, adv_examples


######################################################################
# Run Attack
# ~~~~~~~~~~
#
# The last part of the implementation is to actually
# run the attack. Here,
# we run a full test step for each epsilon value in the *epsilons* input.
# For each epsilon we also save the final accuracy and some successful
# adversarial examples to be plotted in the coming sections. Notice how
# the printed accuracies decrease as the epsilon value increases. Also,
# note the :math:`\epsilon=0` case represents the original test accuracy,
# with no attack.
#

accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)


######################################################################
# Results
# -------
#
# Accuracy vs Epsilon
# ~~~~~~~~~~~~~~~~~~~
#
# The first result is the accuracy versus epsilon plot. As alluded to
# earlier, as epsilon increases we expect the test accuracy to decrease.
# This is because larger epsilons mean we take a larger step in the
# direction that will maximize the loss. Notice the trend in the curve is
# not linear even though the epsilon values are linearly spaced. For
# example, the accuracy at :math:`\epsilon=0.05` is only about 4 percentage
# points lower than at :math:`\epsilon=0`, but the accuracy at
# :math:`\epsilon=0.2` is 25 percentage points lower than at
# :math:`\epsilon=0.15`. Also, notice the accuracy of the model
# drops to chance level for a 10-class classifier (10%) between
# :math:`\epsilon=0.25` and :math:`\epsilon=0.3`.
#

plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()


######################################################################
# Sample Adversarial Examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember the idea of no free lunch?
# In this case, as epsilon increases
# the test accuracy decreases **BUT** the perturbations become more easily
# perceptible. In reality, there is a tradeoff between accuracy
# degradation and perceptibility that an attacker must consider. Here, we
# show some successful adversarial examples at each epsilon
# value. Each row of the plot shows a different epsilon value. The first
# row shows the :math:`\epsilon=0` examples, which represent the original
# “clean” images with no perturbation. The title of each image shows the
# “original classification -> adversarial classification.” Notice that the
# perturbations start to become evident at :math:`\epsilon=0.15` and are
# quite evident at :math:`\epsilon=0.3`. However, in all cases humans are
# still capable of identifying the correct class despite the added noise.
#

# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons),len(examples[0]),cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig,adv,ex = examples[i][j]
        plt.title(f"{orig} -> {adv}")
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()


######################################################################
# Where to go next?
# -----------------
#
# Hopefully this tutorial gives some insight into the topic of adversarial
# machine learning. There are many potential directions to go from here.
# This attack represents the very beginning of adversarial attack research,
# and since then there have been many subsequent ideas for how to attack and
# defend ML models from an adversary.
# In fact, at NIPS 2017 there was an
# adversarial attack and defense competition, and many of the methods used
# in the competition are described in this paper: `Adversarial Attacks and
# Defences Competition <https://arxiv.org/pdf/1804.00097.pdf>`__. The work
# on defense also leads into the idea of making machine learning models
# more *robust* in general, to both naturally perturbed and adversarially
# crafted inputs.
#
# Another direction to go is adversarial attacks and defense in different
# domains. Adversarial research is not limited to the image domain; check
# out `this <https://arxiv.org/pdf/1801.01944.pdf>`__ attack on
# speech-to-text models. But perhaps the best way to learn more about
# adversarial machine learning is to get your hands dirty. Try to
# implement a different attack from the NIPS 2017 competition, and see how
# it differs from FGSM. Then, try to defend the model from your own
# attacks.
#
# A further direction to go, depending on available resources, is to modify
# the code to process work in batches, in parallel, and/or distributed,
# rather than attacking one image at a time inside the per-epsilon
# ``test()`` loop above.
#