GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/jacobians_hessians.py
# -*- coding: utf-8 -*-
"""
Jacobians, Hessians, hvp, vhp, and more: composing function transforms
======================================================================

Computing jacobians or hessians is useful in a number of non-traditional
deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using PyTorch's regular autodiff APIs
(``Tensor.backward()``, ``torch.autograd.grad``). PyTorch's
`JAX-inspired <https://github.com/google/jax>`_
`function transforms API <https://pytorch.org/docs/master/func.html>`_
provides ways of computing various higher-order autodiff quantities
efficiently.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.

Computing the Jacobian
----------------------
"""

import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)

######################################################################
# Let's start with a function that we'd like to compute the jacobian of.
# This is a simple linear function with non-linear activation.

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

######################################################################
# Let's add some dummy data: a weight, a bias, and a feature vector x.

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

######################################################################
# Let's think of ``predict`` as a function that maps the input ``x`` from :math:`R^D \to R^D`.
# PyTorch Autograd computes vector-Jacobian products. In order to compute the full
# Jacobian of this :math:`R^D \to R^D` function, we would have to compute it row-by-row
# by using a different unit vector each time.

def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row
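
######################################################################
# For this particular model the Jacobian also has a closed form,
# ``diag(1 - tanh(z)**2) @ W`` with ``z = W x + b``, which we can use as an
# illustrative cross-check of the loop above:

z = F.linear(x, weight, bias)
analytic_jacobian = (1 - z.tanh() ** 2).unsqueeze(-1) * weight
assert torch.allclose(jacobian, analytic_jacobian)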

######################################################################
# Instead of computing the jacobian row-by-row, we can use PyTorch's
# ``torch.vmap`` function transform to get rid of the for-loop and vectorize the
# computation. We can't directly apply ``vmap`` to ``torch.autograd.grad``;
# instead, PyTorch provides a ``torch.func.vjp`` transform that composes with
# ``torch.vmap``:

from torch.func import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
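
######################################################################
# As an optional, illustrative sanity check, the same Jacobian can also be
# compared against the older ``torch.autograd.functional.jacobian`` API:

from torch.autograd.functional import jacobian as autograd_jacobian

# jacobian of ``predict`` with respect to ``x``, holding weight and bias fixed
ad_jacobian = autograd_jacobian(partial(predict, weight, bias), x)
assert torch.allclose(ad_jacobian, ft_jacobian)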

######################################################################
# In a later tutorial a composition of reverse-mode AD and ``vmap`` will give us
# per-sample-gradients.
# In this tutorial, composing reverse-mode AD and ``vmap`` gives us Jacobian
# computation!
# Various compositions of ``vmap`` and autodiff transforms can give us different
# interesting quantities.
#
# PyTorch provides ``torch.func.jacrev`` as a convenience function that performs
# the ``vmap-vjp`` composition to compute jacobians. ``jacrev`` accepts an ``argnums``
# argument that specifies which argument we would like to compute Jacobians with
# respect to.

from torch.func import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# confirm that we get the same result as before
assert torch.allclose(ft_jacobian, jacobian)

######################################################################
# Let's compare the performance of the two ways to compute the jacobian.
# The function transform version is much faster (and becomes even faster the
# more outputs there are).
#
# In general, we expect that vectorization via ``vmap`` can help eliminate overhead
# and give better utilization of your hardware.
#
# ``vmap`` does this magic by pushing the outer loop down into the function's
# primitive operations in order to obtain better performance.
#
# Let's make a quick helper function that compares two benchmark measurements
# and reports the relative difference (sidestepping the microseconds vs.
# milliseconds units in the printed timings):

def get_perf(first, first_descriptor, second, second_descriptor):
    """Takes torch.utils.benchmark measurements and prints the relative delta,
    naming whichever run was faster."""
    first_time = first.times[0]
    second_time = second.times[0]
    # report the relative improvement, measured against the slower of the two runs
    if second_time <= first_time:
        gain, winner = (first_time - second_time) / first_time, second_descriptor
    else:
        gain, winner = (second_time - first_time) / second_time, first_descriptor
    print(f" Performance delta: {gain * 100:.4f} percent improvement with {winner} ")

######################################################################
# And then run the performance comparison:

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)

######################################################################
# Let's do a relative performance comparison of the above with our ``get_perf`` function:

get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")

######################################################################
# Furthermore, it's pretty easy to flip the problem around and say we want to
# compute Jacobians with respect to the model parameters (weight, bias) instead of the input.

# note the change via the ``argnums`` parameter: (0, 1) maps to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
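
######################################################################
# A quick, illustrative shape check: each parameter Jacobian has one leading
# output dimension followed by the parameter's own shape:

print(ft_jac_weight.shape)  # expected torch.Size([16, 16, 16]), i.e. (D,) + weight.shape
print(ft_jac_bias.shape)    # expected torch.Size([16, 16]), i.e. (D,) + bias.shape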

######################################################################
# Reverse-mode Jacobian (``jacrev``) vs forward-mode Jacobian (``jacfwd``)
# ------------------------------------------------------------------------
#
# We offer two APIs to compute jacobians: ``jacrev`` and ``jacfwd``:
#
# - ``jacrev`` uses reverse-mode AD. As you saw above it is a composition of our
#   ``vjp`` and ``vmap`` transforms.
# - ``jacfwd`` uses forward-mode AD. It is implemented as a composition of our
#   ``jvp`` and ``vmap`` transforms.
#
# ``jacfwd`` and ``jacrev`` can be substituted for each other but they have different
# performance characteristics.
#
# As a general rule of thumb, if you're computing the jacobian of an :math:`R^N \to R^M`
# function, and there are many more outputs than inputs (for example, :math:`M > N`) then
# ``jacfwd`` is preferred, otherwise use ``jacrev``. There are exceptions to this rule,
# but a non-rigorous argument for this follows:
#
# In reverse-mode AD, we are computing the jacobian row-by-row, while in
# forward-mode AD (which computes Jacobian-vector products), we are computing
# it column-by-column. The Jacobian matrix has M rows and N columns, so if it
# is taller (:math:`M > N`), the column-by-column approach of ``jacfwd`` needs
# fewer passes, and if it is wider (:math:`N > M`), the row-by-row approach of
# ``jacrev`` does.

from torch.func import jacrev, jacfwd
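
######################################################################
# To make the row-vs-column picture concrete, here is a small illustrative
# sketch (not required for the rest of the tutorial) that rebuilds the earlier
# :math:`16 \times 16` Jacobian column-by-column with ``torch.func.jvp``,
# mirroring what ``jacfwd`` does, and checks it against the row-by-row result:

from torch.func import jvp

jacobian_cols = [jvp(partial(predict, weight, bias), (x,), (vec,))[1]
                 for vec in unit_vectors]
assert torch.allclose(torch.stack(jacobian_cols, dim=1), jacobian)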

######################################################################
# First, let's benchmark with more outputs than inputs:

Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')

######################################################################
# and then do a relative benchmark:

get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")

#######################################################################
# and now the reverse - more inputs (N) than outputs (M):

Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')

#######################################################################
# and a relative performance comparison:

get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")

#######################################################################
# Hessian computation with torch.func.hessian
# --------------------------------------------
# We offer a convenience API to compute hessians: ``torch.func.hessian``.
# The Hessian is the Jacobian of the Jacobian (or, equivalently, the matrix of
# second-order partial derivatives).
#
# This suggests that one can just compose the jacobian transforms to
# compute the Hessian.
# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.
#
# Note: to boost performance, depending on your model you may also want to
# use ``jacfwd(jacfwd(f))`` or ``jacrev(jacrev(f))`` instead to compute hessians,
# leveraging the rule of thumb above regarding wider vs taller matrices.

from torch.func import hessian

# let's reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)

#######################################################################
# Let's verify we have the same result regardless of whether we use the
# ``hessian`` API or ``jacfwd(jacfwd())``:

print(torch.allclose(hess_api, hess_fwdfwd))
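
#######################################################################
# As an additional illustrative check: each per-output Hessian should be
# symmetric in its last two dimensions, since mixed partial derivatives of the
# smooth ``predict`` function commute:

print(hess_api.shape)  # (Dout, Din, Din) = (32, 512, 512)
print(torch.allclose(hess_api, hess_api.transpose(-2, -1)))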

#######################################################################
# Batch Jacobian and Batch Hessian
# --------------------------------
# In the above examples we've been operating with a single feature vector.
# In some cases you might want to take the Jacobian of a batch of outputs
# with respect to a batch of inputs. That is, given a batch of inputs of
# shape ``(B, N)`` and a function that goes from :math:`R^N \to R^M`, we would like
# a Jacobian of shape ``(B, M, N)``.
#
# The easiest way to do this is to use ``vmap``:

batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)

#######################################################################
# If you have a function that goes from ``(B, N) -> (B, M)`` instead and are
# certain that each input produces an independent output, then it's also
# sometimes possible to do this without using ``vmap`` by summing the outputs
# and then computing the Jacobian of that function:

def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
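
#######################################################################
# Both approaches produce one ``(Dout, Din)`` Jacobian per batch element
# (an illustrative shape check):

print(batch_jacobian0.shape)  # expected torch.Size([64, 33, 31]), i.e. (batch_size, Dout, Din)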

#######################################################################
# If you instead have a function that goes from :math:`R^N \to R^M` but inputs that
# are batched, you can compose ``vmap`` with ``jacrev`` to compute batched jacobians,
# as we did above.
#
# Finally, batch hessians can be computed similarly. It's easiest to think
# about them by using ``vmap`` to batch over hessian computation, but in some
# cases the sum trick also works.

compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
print(batch_hess.shape)
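
#######################################################################
# The leading dimension is the batch; for each batch element there is one
# ``Din x Din`` Hessian per output component, so the shape printed above should
# be ``torch.Size([64, 33, 31, 31])``, i.e. ``(batch_size, Dout, Din, Din)``.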

#######################################################################
# Computing Hessian-vector products
# ---------------------------------
# The naive way to compute a Hessian-vector product (hvp) is to materialize
# the full Hessian and perform a matrix-vector product with the vector. We can
# do better: it turns out we don't need to materialize the full Hessian to do
# this. We'll go through two (of many) different strategies to compute
# Hessian-vector products:
#
# - composing reverse-mode AD with reverse-mode AD
# - composing reverse-mode AD with forward-mode AD
#
# Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode
# with reverse-mode) is generally the more memory efficient way to compute a
# hvp because forward-mode AD doesn't need to construct an Autograd graph and
# save intermediates for backward:

from torch.func import jvp, grad, vjp

def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

#######################################################################
# Here's some sample usage.

def f(x):
    return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))
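
#######################################################################
# For ``f(x) = x.sin().sum()`` the Hessian is ``diag(-sin(x))``, so the
# Hessian-vector product has a closed form we can compare against as an
# illustrative sanity check:

assert torch.allclose(result, -x.sin() * tangent)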

#######################################################################
# If PyTorch forward-AD does not have coverage for your operations, then we can
# instead compose reverse-mode AD with reverse-mode AD:

def hvp_revrev(f, primals, tangents):
    _, vjp_fn = vjp(grad(f), *primals)
    return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])