# -*- coding: utf-8 -*-
"""
Per-sample-gradients
====================

What is it?
-----------

Per-sample-gradient computation is computing the gradient for each and every
sample in a batch of data. It is a useful quantity in differential privacy,
meta-learning, and optimization research.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Here's a simple CNN and loss function:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def loss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)


######################################################################
# Let's generate a batch of dummy data and pretend that we're working with an
# MNIST dataset. The dummy images are 28 by 28 and we use a minibatch of size 64.

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

######################################################################
# In regular model training, one would forward the minibatch through the model,
# and then call ``.backward()`` to compute gradients. This would generate an
# 'average' gradient of the entire mini-batch:

model = SimpleCNN().to(device=device)
predictions = model(data)  # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward()  # back propagate the 'average' gradient of this mini-batch

######################################################################
# In contrast to the above approach, per-sample-gradient computation is
# equivalent to:
#
# - for each individual sample of the data, perform a forward and a backward
#   pass to get an individual (per-sample) gradient.

def compute_grad(sample, target):
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    """Manually process each sample to get per-sample gradients."""
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

######################################################################
# ``per_sample_grads[0]`` is the per-sample-grad for ``model.conv1.weight``.
# ``model.conv1.weight.shape`` is ``[32, 1, 3, 3]``; notice how there is one
# gradient, per sample, in the batch for a total of 64.

print(per_sample_grads[0].shape)
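
######################################################################
# As a quick sanity check: since ``F.nll_loss`` defaults to ``reduction='mean'``,
# averaging the per-sample gradients over the batch dimension should recover,
# up to floating-point error, the 'average' gradient that ``loss.backward()``
# wrote into ``model.conv1.weight.grad`` above:

print((per_sample_grads[0].mean(dim=0) - model.conv1.weight.grad).abs().max())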

######################################################################
# Per-sample-grads, *the efficient way*, using function transforms
# ----------------------------------------------------------------
#
# We can compute per-sample-gradients efficiently by using function transforms.
#
# The ``torch.func`` function transform API transforms over functions.
# Our strategy is to define a function that computes the loss and then apply
# transforms to construct a function that computes per-sample-gradients.
#
# We'll use the ``torch.func.functional_call`` function to treat an ``nn.Module``
# like a function.
#
# First, let's extract the state from ``model`` into two dictionaries,
# parameters and buffers. We'll be detaching them because we won't use
# regular PyTorch autograd (e.g. ``Tensor.backward()``, ``torch.autograd.grad``).

from torch.func import functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

######################################################################
# Next, let's define a function to compute the loss of the model given a
# single input rather than a batch of inputs. It is important that this
# function accepts the parameters, the input, and the target, because we will
# be transforming over them.
#
# Note - because the model was originally written to handle batches, we'll
# use ``torch.unsqueeze`` to add a batch dimension.

def compute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

######################################################################
# Now, let's use the ``grad`` transform to create a new function that computes
# the gradient with respect to the first argument of ``compute_loss``
# (i.e. the ``params``).

ft_compute_grad = grad(compute_loss)

######################################################################
# The ``ft_compute_grad`` function computes the gradient for a single
# (sample, target) pair. We can use ``vmap`` to get it to compute the gradient
# over an entire batch of samples and targets. Note that
# ``in_dims=(None, None, 0, 0)`` because we wish to map ``ft_compute_grad`` over
# the 0th dimension of the data and targets, and use the same ``params`` and
# buffers for each.

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
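
######################################################################
# To build intuition for ``in_dims``, here is a minimal toy sketch (the
# ``scaled_dot`` function below is purely illustrative): ``None`` entries pass
# the same argument to every call, while ``0`` maps over the leading dimension
# of that argument.

def scaled_dot(scale, x, y):
    # scalar output for a single (x, y) pair
    return scale * (x * y).sum()

toy_batched = vmap(scaled_dot, in_dims=(None, 0, 0))
# the same ``scale`` is reused; ``x`` and ``y`` are mapped over dimension 0
print(toy_batched(torch.tensor(2.0), torch.randn(5, 3), torch.randn(5, 3)).shape)  # torch.Size([5])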

######################################################################
# Finally, let's use our transformed function to compute per-sample-gradients:

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

######################################################################
# We can double-check that the results using ``grad`` and ``vmap`` match the
# results of hand processing each sample individually:

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by ``vmap``. The best functions to transform are pure functions:
# functions whose outputs are determined only by their inputs and that have no
# side effects (e.g. mutation). ``vmap`` is unable to handle mutation of
# arbitrary Python data structures, but it is able to handle many in-place
# PyTorch operations.
#
# Performance comparison
# ----------------------
#
# Curious about how the performance of ``vmap`` compares?
#
# Currently the best results are obtained on newer GPUs such as the A100
# (Ampere), where we've seen up to 25x speedups on this example, but here are
# some results on our build machines:

def get_perf(first, first_descriptor, second, second_descriptor):
    """Takes torch.utils.benchmark Measurement objects and compares the delta of the second vs the first."""
    second_res = second.times[0]
    first_res = first.times[0]

    gain = (first_res - second_res) / first_res
    if gain < 0:
        gain *= -1
    final_gain = gain * 100

    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor}")

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)", globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')

get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")

######################################################################
# There are other optimized solutions (like https://github.com/pytorch/opacus)
# for computing per-sample-gradients in PyTorch that also perform better than
# the naive method. But it's cool that composing ``vmap`` and ``grad`` gives us
# a nice speedup.
#
# In general, vectorization with ``vmap`` should be faster than running a function
# in a for-loop and competitive with manual batching. There are some exceptions
# though, like if we haven't implemented the ``vmap`` rule for a particular
# operation or if the underlying kernels weren't optimized for older hardware
# (GPUs). If you see any of these cases, please let us know by opening an issue
# on GitHub.
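
######################################################################
# As a closing sketch of the differential-privacy use case mentioned at the
# start: per-sample gradients make it straightforward to compute one gradient
# norm per sample, which is the quantity that DP-SGD clips before averaging.
# This is only a rough illustration; a full DP-SGD implementation (e.g. Opacus)
# also performs the clipping and adds noise.

flat_grads = torch.cat(
    [g.reshape(batch_size, -1) for g in ft_per_sample_grads.values()], dim=1
)
per_sample_grad_norms = flat_grads.norm(dim=1)  # one L2 norm per sample
print(per_sample_grad_norms.shape)  # torch.Size([64])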