CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
pytorch

CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!

GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/per_sample_grads.py
Views: 494
1
# -*- coding: utf-8 -*-
2
"""
3
Per-sample-gradients
4
====================
5
6
What is it?
7
-----------
8
9
Per-sample-gradient computation is computing the gradient for each and every
10
sample in a batch of data. It is a useful quantity in differential privacy,
11
meta-learning, and optimization research.
12
13
.. note::
14
15
This tutorial requires PyTorch 2.0.0 or later.
16
17
"""
18
19
import torch
20
import torch.nn as nn
21
import torch.nn.functional as F
22
# Fix the global RNG seed so the dummy data drawn below (torch.randn /
# torch.randint) and the freshly initialized model weights are reproducible.
torch.random.manual_seed(0)
23
24
# Here's a simple CNN and loss function:
25
26
class SimpleCNN(nn.Module):
    """Small MNIST-style convnet producing log-probabilities over 10 classes.

    Two unpadded 3x3 convolutions, a 2x2 max-pool and two fully connected
    layers. ``fc1`` expects 9216 = 64 * 12 * 12 features, which is what a
    1x28x28 input yields: 28 -> 26 -> 24 after the convolutions, then 12
    after pooling.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the order of ``named_parameters()``
        # (conv1, conv2, fc1, fc2), which later code relies on.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool.
        features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        # Classifier head on the flattened features (batch dimension kept).
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return F.log_softmax(self.fc2(hidden), dim=1)
46
47
def loss_fn(predictions, targets):
    """Negative log-likelihood loss with ``F.nll_loss``'s default reduction.

    ``predictions`` are log-probabilities (as returned by ``SimpleCNN``) and
    ``targets`` are the corresponding class indices.
    """
    return F.nll_loss(input=predictions, target=targets)
49
50
51
######################################################################
52
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
53
# The dummy images are 28 by 28 and we use a minibatch of size 64.
54
55
# Use the GPU when one is present, otherwise fall back to the CPU so the
# tutorial still runs; nothing below requires CUDA specifically.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 64
# Random "images" shaped (batch, channels=1, height=28, width=28).
data = torch.randn(batch_size, 1, 28, 28, device=device)

# One class label in [0, 10) per image. The size comes from ``batch_size``
# (not a literal) so targets always line up with ``data``.
targets = torch.randint(10, (batch_size,), device=device)
62
63
######################################################################
64
# In regular model training, one would forward the minibatch through the model,
65
# and then call .backward() to compute gradients. This would generate an
66
# 'average' gradient of the entire mini-batch:
67
68
# Build the model and move it to the chosen device.
model = SimpleCNN()
model = model.to(device)

# One batched forward pass over the whole mini-batch.
predictions = model(data)
loss = loss_fn(predictions, targets)

# loss_fn averages over the batch, so backward() stores the batch-average
# gradient in each parameter's ``.grad``.
loss.backward()
73
74
######################################################################
75
# In contrast to the above approach, per-sample-gradient computation is
76
# equivalent to:
77
#
78
# - for each individual sample of the data, perform a forward and a backward
79
# pass to get an individual (per-sample) gradient.
80
81
def compute_grad(sample, target):
    """Return the loss gradient for a single (sample, target) pair.

    The result is the tuple produced by ``torch.autograd.grad``: one gradient
    tensor per parameter of the global ``model``, in ``model.parameters()``
    order.
    """
    # ``model`` and ``loss_fn`` work on batches, so give the lone sample and
    # target a leading batch dimension of size 1.
    batched_sample = sample[None]
    batched_target = target[None]

    loss = loss_fn(model(batched_sample), batched_target)
    return torch.autograd.grad(loss, list(model.parameters()))
89
90
91
def compute_sample_grads(data, targets):
    """Compute per-sample gradients the naive way, one sample at a time.

    Args:
        data: batch of inputs whose first dimension indexes the samples.
        targets: one target per sample in ``data``.

    Returns:
        A list with one tensor per model parameter, in the order that
        ``compute_grad`` returns them. Each tensor stacks that parameter's
        per-sample gradients along a new leading dimension of size
        ``len(data)``.

    Raises:
        ValueError: if ``data`` and ``targets`` hold different numbers of
            samples.
    """
    if len(data) != len(targets):
        raise ValueError(
            "data and targets must have the same number of samples, "
            f"got {len(data)} and {len(targets)}"
        )
    # Walk the batch that was actually passed in (not the global
    # ``batch_size``), so any batch size works.
    sample_grads = [compute_grad(sample, target) for sample, target in zip(data, targets)]
    # Regroup "per sample: tuple of parameter grads" into
    # "per parameter: grads of all samples".
    return [torch.stack(shards) for shards in zip(*sample_grads)]
97
98
# Reference result: per-sample gradients computed with an explicit Python
# loop over the batch.
per_sample_grads = compute_sample_grads(data=data, targets=targets)

######################################################################
# ``per_sample_grads[0]`` is the per-sample-grad for ``model.conv1.weight``.
# ``model.conv1.weight.shape`` is ``[32, 1, 3, 3]``; notice how there is one
# gradient per sample in the batch, for a total of 64, so the printed shape
# is ``[64, 32, 1, 3, 3]``.

print(per_sample_grads[0].size())
106
107
######################################################################
108
# Per-sample-grads, *the efficient way*, using function transforms
109
# ----------------------------------------------------------------
110
# We can compute per-sample-gradients efficiently by using function transforms.
111
#
112
# The ``torch.func`` function transform API transforms over functions.
113
# Our strategy is to define a function that computes the loss and then apply
114
# transforms to construct a function that computes per-sample-gradients.
115
#
116
# We'll use the ``torch.func.functional_call`` function to treat an ``nn.Module``
117
# like a function.
118
#
119
# First, let’s extract the state from ``model`` into two dictionaries,
120
# parameters and buffers. We'll be detaching them because we won't use
121
# regular PyTorch autograd (e.g. Tensor.backward(), torch.autograd.grad).
122
123
from torch.func import functional_call, vmap, grad
124
125
# Detached copies of the model state, keyed by qualified name
# (e.g. ``conv1.weight``). SimpleCNN registers no buffers, so ``buffers``
# is empty here.
params = {name: p.detach() for name, p in model.named_parameters()}
buffers = {name: b.detach() for name, b in model.named_buffers()}
127
128
######################################################################
129
# Next, let's define a function to compute the loss of the model given a
130
# single input rather than a batch of inputs. It is important that this
131
# function accepts the parameters, the input, and the target, because we will
132
# be transforming over them.
133
#
134
# Note - because the model was originally written to handle batches, we’ll
135
# use ``torch.unsqueeze`` to add a batch dimension.
136
137
def compute_loss(params, buffers, sample, target):
    """Loss of ``model`` on one (sample, target) pair, with explicit state.

    ``functional_call`` runs ``model`` with the supplied ``params`` and
    ``buffers`` rather than its own registered tensors, which is what lets
    ``grad`` and ``vmap`` transform over them.
    """
    # The model was written for batches: add a batch dimension of size 1.
    predictions = functional_call(model, (params, buffers), (sample.unsqueeze(0),))
    return loss_fn(predictions, target.unsqueeze(0))
144
145
######################################################################
146
# Now, let’s use the ``grad`` transform to create a new function that computes
147
# the gradient with respect to the first argument of ``compute_loss``
148
# (i.e. the ``params``).
149
150
# Differentiate compute_loss w.r.t. its first argument (``params``).
ft_compute_grad = grad(compute_loss, argnums=0)
151
152
######################################################################
153
# The ``ft_compute_grad`` function computes the gradient for a single
154
# (sample, target) pair. We can use ``vmap`` to get it to compute the gradient
155
# over an entire batch of samples and targets. Note that
156
# ``in_dims=(None, None, 0, 0)`` because we wish to map ``ft_compute_grad`` over
157
# the 0th dimension of the data and targets, and use the same ``params`` and
158
# ``buffers`` for each.
159
160
# Map over dim 0 of ``sample`` and ``target``. ``None`` broadcasts the same
# ``params`` and ``buffers`` to every sample; results are stacked along dim 0.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0), out_dims=0)
161
162
######################################################################
163
# Finally, let's use our transformed function to compute per-sample-gradients:
164
165
# Dict keyed like ``params``; each value stacks the per-sample gradients of
# that parameter along a leading batch dimension. Arguments must stay
# positional because ``in_dims`` only applies to positional arguments.
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
166
167
######################################################################
168
# We can double-check that the results using ``grad`` and ``vmap`` match the
169
# results of hand processing each one individually:
170
171
# Pair the loop-based gradients with the vmap-based ones parameter by
# parameter. This relies on ``params`` (built from ``named_parameters()``)
# iterating in the same order as ``model.parameters()``.
for per_sample_grad, ft_per_sample_grad in zip(
    per_sample_grads,
    ft_per_sample_grads.values(),
):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, rtol=1e-5, atol=3e-3)
173
174
######################################################################
175
# A quick note: there are limitations around what types of functions can be
176
# transformed by ``vmap``. The best functions to transform are ones that are pure
177
# functions: a function where the outputs are only determined by the inputs,
178
# and that have no side effects (e.g. mutation). ``vmap`` is unable to handle
179
# mutation of arbitrary Python data structures, but it is able to handle many
180
# in-place PyTorch operations.
181
#
182
# Performance comparison
183
# ----------------------
184
#
185
# Curious about how the performance of ``vmap`` compares?
186
#
187
# Currently the best results are obtained on newer GPUs such as the A100
188
# (Ampere) where we've seen up to 25x speedups on this example, but here are
189
# some results on our build machines:
190
191
def get_perf(first, first_descriptor, second, second_descriptor):
    """Print how much faster the quicker of two benchmark measurements is.

    Args:
        first, second: ``torch.utils.benchmark`` measurements; only the first
            entry of each ``times`` list is read.
        first_descriptor, second_descriptor: labels for ``first`` and
            ``second`` used in the printed message.

    The reported delta is ``(slower - faster) / faster * 100``. It is credited
    to whichever measurement was actually faster (ties go to ``first``).
    """
    first_res = first.times[0]
    second_res = second.times[0]

    # Keep track of which side won instead of discarding the sign; otherwise
    # a slower ``first`` would still be reported as an improvement.
    if first_res <= second_res:
        faster_res, slower_res, faster_descriptor = first_res, second_res, first_descriptor
    else:
        faster_res, slower_res, faster_descriptor = second_res, first_res, second_descriptor

    final_gain = (slower_res - faster_res) / faster_res * 100

    print(f"Performance delta: {final_gain:.4f} percent improvement with {faster_descriptor} ")
201
202
from torch.utils.benchmark import Timer
203
204
# Build one timer per approach. ``globals=globals()`` lets the statement
# strings see this script's functions and tensors.
without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(
    stmt="ft_compute_sample_grad(params, buffers, data, targets)",
    globals=globals(),
)

# Time 100 executions of each statement (loop version first).
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')

get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
213
214
######################################################################
215
# There are other optimized solutions (like in https://github.com/pytorch/opacus)
216
# to computing per-sample-gradients in PyTorch that also perform better than
217
# the naive method. But it’s cool that composing ``vmap`` and ``grad`` gives us a
218
# nice speedup.
219
#
220
# In general, vectorization with ``vmap`` should be faster than running a function
221
# in a for-loop and competitive with manual batching. There are some exceptions
222
# though, like if we haven’t implemented the ``vmap`` rule for a particular
223
# operation or if the underlying kernels weren’t optimized for older hardware
224
# (GPUs). If you see any of these cases, please let us know by opening an issue
225
# on GitHub.
226
227