# -*- coding: utf-8 -*-
"""
Adversarial Example Generation
==============================

**Author:** `Nathan Inkawhich <https://github.com/inkawhich>`__

If you are reading this, hopefully you can appreciate how effective some
machine learning models are. Research is constantly pushing ML models to
be faster, more accurate, and more efficient. However, an often
overlooked aspect of designing and training models is security and
robustness, especially in the face of an adversary who wishes to fool
the model.

This tutorial will raise your awareness of the security vulnerabilities
of ML models, and will give insight into the hot topic of adversarial
machine learning. You may be surprised to find that adding imperceptible
perturbations to an image *can* cause drastically different model
performance. Given that this is a tutorial, we will explore the topic
via an example on an image classifier. Specifically, we will use one of
the first and most popular attack methods, the Fast Gradient Sign Attack
(FGSM), to fool an MNIST classifier.

"""


######################################################################
# Threat Model
# ------------
#
# For context, there are many categories of adversarial attacks, each with
# a different goal and assumption of the attacker’s knowledge. However, in
# general the overarching goal is to add the least amount of perturbation
# to the input data to cause the desired misclassification. There are
# several kinds of assumptions of the attacker’s knowledge, two of which
# are: **white-box** and **black-box**. A *white-box* attack assumes the
# attacker has full knowledge and access to the model, including
# architecture, inputs, outputs, and weights. A *black-box* attack assumes
# the attacker only has access to the inputs and outputs of the model, and
# knows nothing about the underlying architecture or weights. There are
# also several types of goals, including **misclassification** and
# **source/target misclassification**. A goal of *misclassification* means
# the adversary only wants the output classification to be wrong but does
# not care what the new classification is. A *source/target
# misclassification* means the adversary wants to alter an image that is
# originally of a specific source class so that it is classified as a
# specific target class.
#
# In this case, the FGSM attack is a *white-box* attack with the goal of
# *misclassification*. With this background information, we can now
# discuss the attack in detail.
#
# Fast Gradient Sign Attack
# -------------------------
#
# One of the first and most popular adversarial attacks to date is
# referred to as the *Fast Gradient Sign Attack (FGSM)* and is described
# by Goodfellow et al. in `Explaining and Harnessing Adversarial
# Examples <https://arxiv.org/abs/1412.6572>`__. The attack is remarkably
# powerful, and yet intuitive. It is designed to attack neural networks by
# leveraging the way they learn: *gradients*. The idea is simple: rather
# than working to minimize the loss by adjusting the weights based on the
# backpropagated gradients, the attack *adjusts the input data to maximize
# the loss* based on the same backpropagated gradients. In other words,
# the attack uses the gradient of the loss w.r.t. the input data, then
# adjusts the input data to maximize the loss.
#
# Before we jump into the code, let’s look at the famous
# `FGSM <https://arxiv.org/abs/1412.6572>`__ panda example and extract
# some notation.
#
# .. figure:: /_static/img/fgsm_panda_image.png
#    :alt: fgsm_panda_image
#
# From the figure, :math:`\mathbf{x}` is the original input image
# correctly classified as a “panda”, :math:`y` is the ground truth label
# for :math:`\mathbf{x}`, :math:`\mathbf{\theta}` represents the model
# parameters, and :math:`J(\mathbf{\theta}, \mathbf{x}, y)` is the loss
# that is used to train the network. The attack backpropagates the
# gradient to the input data to calculate
# :math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`. Then, it adjusts
# the input data by a small step (:math:`\epsilon` or :math:`0.007` in the
# picture) in the direction (i.e.
# :math:`sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))`) that will
# maximize the loss. The resulting perturbed image, :math:`x'`, is then
# *misclassified* by the target network as a “gibbon” when it is still
# clearly a “panda”.
#
# Hopefully now the motivation for this tutorial is clear, so let’s jump
# into the implementation.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt


######################################################################
# Implementation
# --------------
#
# In this section, we will discuss the input parameters for the tutorial,
# define the model under attack, then code the attack and run some tests.
#
# Inputs
# ~~~~~~
#
# There are only three inputs for this tutorial, and they are defined as
# follows:
#
# -  ``epsilons`` - List of epsilon values to use for the run. It is
#    important to keep 0 in the list because it represents the model
#    performance on the original test set. Also, intuitively we would
#    expect the larger the epsilon, the more noticeable the perturbations
#    but the more effective the attack in terms of degrading model
#    accuracy. Since the data range here is :math:`[0,1]`, no epsilon
#    value should exceed 1.
#
# -  ``pretrained_model`` - path to the pretrained MNIST model which was
#    trained with
#    `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
#    For simplicity, download the pretrained model `here <https://drive.google.com/file/d/1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl/view?usp=drive_link>`__.
#
# -  ``use_cuda`` - boolean flag to use CUDA if desired and available.
#    Note, a GPU with CUDA is not critical for this tutorial as a CPU will
#    not take much time.
#

epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "data/lenet_mnist_model.pth"
use_cuda = True
# Set random seed for reproducibility
torch.manual_seed(42)
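

######################################################################
# As an optional aside, the core FGSM update from the panda example above
# can be illustrated on a toy tensor before we touch the real classifier.
# The snippet below is only a minimal sketch: ``toy_x`` stands in for an
# input image, a simple squared-error loss stands in for the training loss
# :math:`J`, and the update is :math:`x + \epsilon \cdot sign(\nabla_{x} J)`
# with an arbitrary :math:`\epsilon=0.1`. None of these toy variables are
# used in the rest of the tutorial.

toy_x = torch.rand(1, 4, requires_grad=True)      # stand-in for an input image
toy_target = torch.zeros(1, 4)                    # stand-in for a fixed target
toy_loss = F.mse_loss(toy_x, toy_target)          # stand-in for the loss J(theta, x, y)
toy_loss.backward()                               # populates toy_x.grad with dJ/dx
toy_perturbed = toy_x + 0.1 * toy_x.grad.sign()   # x + epsilon * sign(grad)
print("Sign of the input gradient:", toy_x.grad.sign())
print("Perturbed toy input:", toy_perturbed.detach())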


######################################################################
# Model Under Attack
# ~~~~~~~~~~~~~~~~~~
#
# As mentioned, the model under attack is the same MNIST model from
# `pytorch/examples/mnist <https://github.com/pytorch/examples/tree/master/mnist>`__.
# You may train and save your own MNIST model or you can download and use
# the provided model. The *Net* definition and test dataloader here have
# been copied from the MNIST example. The purpose of this section is to
# define the model and dataloader, then initialize the model and load the
# pretrained weights.
#

# LeNet Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# MNIST Test dataset and dataloader declaration
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=1, shuffle=True)

# Define what device we are using
print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# Initialize the network
model = Net().to(device)

# Load the pretrained model
model.load_state_dict(torch.load(pretrained_model, map_location=device, weights_only=True))

# Set the model in evaluation mode. In this case this is for the Dropout layers
model.eval()
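

######################################################################
# As a quick optional sanity check, we can push a single test image
# through the freshly loaded network and confirm that it produces a
# sensible prediction before any attack is applied. This is only an
# informal spot check; the systematic evaluation happens in the ``test``
# function defined later.

sample_data, sample_target = next(iter(test_loader))
sample_data = sample_data.to(device)
with torch.no_grad():
    sample_output = model(sample_data)
print(f"True label: {sample_target.item()}, predicted: {sample_output.argmax(dim=1).item()}")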


######################################################################
# FGSM Attack
# ~~~~~~~~~~~
#
# Now, we can define the function that creates the adversarial examples by
# perturbing the original inputs. The ``fgsm_attack`` function takes three
# inputs: *image* is the original clean image (:math:`x`), *epsilon* is
# the pixel-wise perturbation amount (:math:`\epsilon`), and *data_grad*
# is the gradient of the loss w.r.t. the input image
# (:math:`\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y)`). The function
# then creates the perturbed image as
#
# .. math:: perturbed\_image = image + epsilon*sign(data\_grad) = x + \epsilon * sign(\nabla_{x} J(\mathbf{\theta}, \mathbf{x}, y))
#
# Finally, in order to maintain the original range of the data, the
# perturbed image is clipped to the range :math:`[0,1]`.
#

# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    # Collect the element-wise sign of the data gradient
    sign_data_grad = data_grad.sign()
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + epsilon*sign_data_grad
    # Add clipping to maintain the [0,1] range
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

# Restores the tensors to their original scale
def denorm(batch, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of normalized tensors back to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: Batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean).to(device)
    if isinstance(std, list):
        std = torch.tensor(std).to(device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
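

######################################################################
# A small optional check on ``denorm``: denormalizing a batch and then
# re-applying the same normalization should recover the original tensors
# up to floating-point error. This mirrors how ``denorm`` and
# ``transforms.Normalize`` are paired inside the ``test`` function below,
# but is not itself part of the attack.

check_batch, _ = next(iter(test_loader))
check_batch = check_batch.to(device)
renormalized = transforms.Normalize((0.1307,), (0.3081,))(denorm(check_batch))
print("Round-trip max abs error:", (renormalized - check_batch).abs().max().item())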


######################################################################
# Testing Function
# ~~~~~~~~~~~~~~~~
#
# Finally, the central result of this tutorial comes from the ``test``
# function. Each call to this test function performs a full test step on
# the MNIST test set and reports a final accuracy. However, notice that
# this function also takes an *epsilon* input. This is because the
# ``test`` function reports the accuracy of a model that is under attack
# from an adversary with strength :math:`\epsilon`. More specifically, for
# each sample in the test set, the function computes the gradient of the
# loss w.r.t. the input data (:math:`data\_grad`), creates a perturbed
# image with ``fgsm_attack`` (:math:`perturbed\_data`), then checks to see
# if the perturbed example is adversarial. In addition to testing the
# accuracy of the model, the function also saves and returns some
# successful adversarial examples to be visualized later.
#

def test(model, device, test_loader, epsilon):

    # Accuracy counter
    correct = 0
    adv_examples = []

    # Loop over all examples in test set
    for data, target in test_loader:

        # Send the data and label to the device
        data, target = data.to(device), target.to(device)

        # Set requires_grad attribute of tensor. Important for the attack
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue

        # Calculate the loss
        loss = F.nll_loss(output, target)

        # Zero all existing gradients
        model.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect the gradient of the data, ``data_grad``
        data_grad = data.grad.data

        # Restore the data to its original scale
        data_denorm = denorm(data)

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        # Reapply normalization
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

        # Re-classify the perturbed image
        output = model(perturbed_data_normalized)

        # Check for success
        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
        if final_pred.item() == target.item():
            correct += 1
            # Special case for saving 0 epsilon examples
            if epsilon == 0 and len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
        else:
            # Save some adv examples for visualization later
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    # Calculate final accuracy for this epsilon
    final_acc = correct / float(len(test_loader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_loader)} = {final_acc}")

    # Return the accuracy and an adversarial example
    return final_acc, adv_examples


######################################################################
# Run Attack
# ~~~~~~~~~~
#
# The last part of the implementation is to actually run the attack. Here,
# we run a full test step for each epsilon value in the *epsilons* input.
# For each epsilon we also save the final accuracy and some successful
# adversarial examples to be plotted in the coming sections. Notice how
# the printed accuracies decrease as the epsilon value increases. Also,
# note the :math:`\epsilon=0` case represents the original test accuracy,
# with no attack.
#

accuracies = []
examples = []

# Run test for each epsilon
for eps in epsilons:
    acc, ex = test(model, device, test_loader, eps)
    accuracies.append(acc)
    examples.append(ex)
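

######################################################################
# Before plotting, it can be handy to print a compact summary of the
# collected results. This is purely for inspection and repeats the numbers
# already printed by ``test`` above.

for eps, acc in zip(epsilons, accuracies):
    print(f"epsilon = {eps:<4} -> accuracy = {acc:.4f}")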


######################################################################
# Results
# -------
#
# Accuracy vs Epsilon
# ~~~~~~~~~~~~~~~~~~~
#
# The first result is the accuracy versus epsilon plot. As alluded to
# earlier, as epsilon increases we expect the test accuracy to decrease.
# This is because larger epsilons mean we take a larger step in the
# direction that will maximize the loss. Notice the trend in the curve is
# not linear even though the epsilon values are linearly spaced. For
# example, the accuracy at :math:`\epsilon=0.05` is only about 4% lower
# than :math:`\epsilon=0`, but the accuracy at :math:`\epsilon=0.2` is 25%
# lower than :math:`\epsilon=0.15`. Also, notice that the model's accuracy
# falls to the random-guess accuracy of a 10-class classifier (about 10%)
# between :math:`\epsilon=0.25` and :math:`\epsilon=0.3`.
#

plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()


######################################################################
# Sample Adversarial Examples
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember the idea of no free lunch? In this case, as epsilon increases
# the test accuracy decreases **BUT** the perturbations become more easily
# perceptible. In reality, there is a tradeoff between accuracy
# degradation and perceptibility that an attacker must consider. Here, we
# show some examples of successful adversarial examples at each epsilon
# value. Each row of the plot shows a different epsilon value. The first
# row is the :math:`\epsilon=0` examples which represent the original
# “clean” images with no perturbation. The title of each image shows the
# “original classification -> adversarial classification.” Notice that the
# perturbations start to become evident at :math:`\epsilon=0.15` and are
# quite evident at :math:`\epsilon=0.3`. However, in all cases humans are
# still capable of identifying the correct class despite the added noise.
#

# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):
    for j in range(len(examples[i])):
        cnt += 1
        plt.subplot(len(epsilons), len(examples[0]), cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f"Eps: {epsilons[i]}", fontsize=14)
        orig, adv, ex = examples[i][j]
        plt.title(f"{orig} -> {adv}")
        plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()


######################################################################
# Where to go next?
# -----------------
#
# Hopefully this tutorial gives some insight into the topic of adversarial
# machine learning. There are many potential directions to go from here.
# This attack represents the very beginning of adversarial attack research
# and since then there have been many subsequent ideas for how to attack and
# defend ML models from an adversary. In fact, at NIPS 2017 there was an
# adversarial attack and defense competition and many of the methods used
# in the competition are described in this paper: `Adversarial Attacks and
# Defences Competition <https://arxiv.org/pdf/1804.00097.pdf>`__. The work
# on defense also leads into the idea of making machine learning models
# more *robust* in general, to both naturally perturbed and adversarially
# crafted inputs.
#
# Another direction to go is adversarial attacks and defenses in different
# domains. Adversarial research is not limited to the image domain; check
# out `this <https://arxiv.org/pdf/1801.01944.pdf>`__ attack on
# speech-to-text models. But perhaps the best way to learn more about
# adversarial machine learning is to get your hands dirty. Try to
# implement a different attack from the NIPS 2017 competition, and see how
# it differs from FGSM. Then, try to defend the model from your own attacks.
#
# A further direction to go, depending on available resources, is to modify
# the code to process batches of examples at once, in parallel, or in a
# distributed setting, rather than attacking one example at a time inside
# the per-epsilon ``test()`` loop above; a minimal starting point for the
# batched direction is sketched below.
#
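

######################################################################
# As a starting point for the batched direction mentioned above, here is a
# minimal sketch (not a drop-in replacement for ``test``) of how one FGSM
# step can be applied to an entire batch at once. It assumes a second
# dataloader with a larger ``batch_size`` and reuses ``model``, ``denorm``,
# and ``fgsm_attack`` from earlier; the epsilon value of 0.3 is arbitrary.

batched_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=128, shuffle=True)

# Take one batch and make the inputs require gradients
batch_data, batch_target = next(iter(batched_loader))
batch_data, batch_target = batch_data.to(device), batch_target.to(device)
batch_data.requires_grad = True

# Forward/backward pass to get the gradient of the loss w.r.t. the inputs
batch_output = model(batch_data)
batch_loss = F.nll_loss(batch_output, batch_target)
model.zero_grad()
batch_loss.backward()

# Perturb the whole batch in one call and re-classify it
perturbed_batch = fgsm_attack(denorm(batch_data), 0.3, batch_data.grad.data)
perturbed_batch = transforms.Normalize((0.1307,), (0.3081,))(perturbed_batch)
batch_preds = model(perturbed_batch).max(1)[1]
print("Accuracy on one attacked batch:",
      (batch_preds == batch_target).float().mean().item())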