GitHub Repository: pytorch/tutorials
Path: blob/main/intermediate_source/forward_ad_usage.py
# -*- coding: utf-8 -*-
"""
Forward-mode Automatic Differentiation (Beta)
=============================================

This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or equivalently, Jacobian-vector products).

The tutorial below uses some APIs only available in versions >= 1.11
(or nightly builds).

Also note that forward-mode AD is currently in beta. The API is
subject to change and operator coverage is still incomplete.

Basic Usage
--------------------------------------------------------------------
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
alongside the forward pass. We can use forward-mode AD to compute a
directional derivative by performing the forward pass as before,
except we first associate our input with another tensor representing
the direction of the directional derivative (or equivalently, the ``v``
in a Jacobian-vector product). When an input, which we call "primal", is
associated with a "direction" tensor, which we call "tangent", the
resultant new tensor object is called a "dual tensor" for its connection
to dual numbers [0].

As the forward pass is performed, if any input tensors are dual tensors,
extra computation is performed to propagate this "sensitivity" of the
function.

"""

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed inside a ``dual_level``
# context. All dual tensors created in such a context will have their
# tangents destroyed upon exit. This ensures that if the output or
# intermediate results of this computation are reused in a future forward
# AD computation, their tangents (which are associated with this
# computation) won't be confused with tangents from the later computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal.
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes.
    jvp = fwAD.unpack_dual(dual_output).tangent

# Once the ``dual_level`` context exits, the tangents are destroyed.
assert fwAD.unpack_dual(dual_output).tangent is None
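
######################################################################
# As a sanity check, the tangent computed above should equal the analytic
# directional derivative. For ``fn(x, y) = x ** 2 + y ** 2`` with a tangent
# ``v`` on ``x`` only, the JVP is ``2 * x * v`` (elementwise). The sketch
# below recomputes this with fresh double-precision tensors (the ``check_*``
# names are illustrative) and also compares it against
# ``torch.autograd.functional.jvp``, which obtains the same product through
# reverse-mode AD rather than dual tensors.

import torch
import torch.autograd.forward_ad as fwAD

# Illustrative inputs: primal x, tangent v for x, and a plain y
check_x = torch.randn(10, 10, dtype=torch.double)
check_v = torch.randn(10, 10, dtype=torch.double)
check_y = torch.randn(10, 10, dtype=torch.double)

with fwAD.dual_level():
    # ``check_y`` has no tangent, so it contributes a zero tangent
    check_out = fwAD.make_dual(check_x, check_v) ** 2 + check_y ** 2
    check_jvp = fwAD.unpack_dual(check_out).tangent

# Analytic JVP: the Jacobian of x ** 2 is diag(2 * x), so J @ v = 2 * x * v
assert torch.allclose(check_jvp, 2 * check_x * check_v)

# Reference JVP via reverse mode; ``check_y`` is closed over, not differentiated
_, ref_jvp = torch.autograd.functional.jvp(
    lambda x: x ** 2 + check_y ** 2, (check_x,), (check_v,))
assert torch.allclose(check_jvp, ref_jvp)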

######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass. At the
# time of writing, it is not possible to create dual tensor
# ``nn.Parameter`` objects. As a workaround, delete each parameter and
# register its dual tensor as a regular (non-parameter) attribute of
# the module instead.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent
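
######################################################################
# Because ``nn.Linear`` computes ``input @ weight.T + bias``, which is
# linear in its parameters, the JVP with respect to the parameters along
# tangents ``(dW, db)`` should be ``input @ dW.T + db``. The sketch below
# repeats the ``delattr``/``setattr`` recipe on a fresh double-precision
# module (the ``lin_*`` names are illustrative) and checks that formula.

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

# Illustrative module, input, parameters, and tangents
lin_model = nn.Linear(5, 5, dtype=torch.double)
lin_input = torch.randn(16, 5, dtype=torch.double)
lin_params = dict(lin_model.named_parameters())
lin_tangents = {name: torch.rand_like(p) for name, p in lin_params.items()}

with fwAD.dual_level():
    for name, p in lin_params.items():
        # Replace each registered parameter with a plain dual-tensor attribute
        delattr(lin_model, name)
        setattr(lin_model, name, fwAD.make_dual(p, lin_tangents[name]))
    lin_jvp = fwAD.unpack_dual(lin_model(lin_input)).tangent

# Analytic JVP of input @ W.T + b along (dW, db)
lin_expected = lin_input @ lin_tangents["weight"].T + lin_tangents["bias"]
assert torch.allclose(lin_jvp, lin_expected)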

######################################################################
# Using the functional Module API (beta)
# --------------------------------------------------------------------
# Another way to use ``nn.Module`` with forward AD is to use
# the functional Module API (also known as the stateless Module API).

from torch.func import functional_call

# We need a fresh module because ``functional_call`` requires the model
# to have its parameters registered, and the previous section replaced
# them with plain attributes.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)
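
######################################################################
# One practical difference between the two approaches is that
# ``functional_call`` only substitutes the supplied tensors for the
# duration of the call, so the module keeps its registered ``nn.Parameter``
# objects afterwards. In contrast, the ``delattr``/``setattr`` recipe
# permanently replaces them. The sketch below assumes PyTorch 2.0 or later
# (for ``torch.func``); the ``fc_*`` names are illustrative. It checks
# both the JVP and the module's parameters after the call.

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD
from torch.func import functional_call

# Illustrative module, input, parameters, and tangents
fc_model = nn.Linear(5, 5, dtype=torch.double)
fc_input = torch.randn(16, 5, dtype=torch.double)
fc_params = dict(fc_model.named_parameters())
fc_tangents = {name: torch.rand_like(p) for name, p in fc_params.items()}

with fwAD.dual_level():
    fc_duals = {name: fwAD.make_dual(p, fc_tangents[name])
                for name, p in fc_params.items()}
    fc_out = functional_call(fc_model, fc_duals, (fc_input,))
    fc_jvp = fwAD.unpack_dual(fc_out).tangent

# Analytic JVP of a linear layer along (dW, db): input @ dW.T + db
assert torch.allclose(fc_jvp, fc_input @ fc_tangents["weight"].T + fc_tangents["bias"])
# The module still has its original parameters registered
assert isinstance(fc_model.weight, nn.Parameter)
assert sorted(name for name, _ in fc_model.named_parameters()) == ["bias", "weight"]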

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# that supports forward-mode AD, define the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ``ctx`` can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        # The derivative of exp(x) is exp(x), so the JVP is the tangent
        # scaled elementwise by the saved result.
        gO = gI * ctx.result
        # If the tensor stored in ``ctx`` will not also be used in the backward pass,
        # one can manually free it using ``del``.
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# ``gradcheck`` only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can tell ``gradcheck``
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)
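
######################################################################
# As noted above, a custom Function can support both forward and backward
# AD. The sketch below is a hypothetical ``Square`` Function (not part of
# PyTorch) that defines both ``backward()`` and ``jvp()`` for ``x ** 2``,
# whose derivative is ``2 * x``. Because both formulas are implemented,
# ``gradcheck`` can keep its default backward-mode checks and also run
# ``check_forward_ad=True``.

import torch

class Square(torch.autograd.Function):
    """Hypothetical example Function computing x ** 2."""

    @staticmethod
    def forward(ctx, x):
        # Saved for backward(); read back through ``ctx.saved_tensors``
        ctx.save_for_backward(x)
        # Computed inside forward() and stored on ``ctx`` for jvp(),
        # following the same pattern as ``Fn`` above
        ctx.two_x = 2 * x
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Vector-Jacobian product: grad_out * d(x ** 2)/dx
        return grad_out * 2 * x

    @staticmethod
    def jvp(ctx, x_t):
        # Jacobian-vector product: d(x ** 2)/dx * tangent
        return ctx.two_x * x_t

# Illustrative double-precision input, as gradcheck recommends
sq_input = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
sq_ok = torch.autograd.gradcheck(Square.apply, (sq_input,), check_forward_ad=True)
print(sq_ok)  # True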

######################################################################
# Functional API (beta)
# --------------------------------------------------------------------
# We also offer a higher-level functional API in functorch
# for computing Jacobian-vector products that you may find simpler to use
# depending on your use case.
#
# The benefit of the functional API is that there isn't a need to understand
# or use the lower-level dual tensor API and that you can compose it with
# other `functorch transforms (like vmap) <https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html>`_;
# the downside is that it offers you less control.
#
# Note that the remainder of this tutorial requires functorch
# (https://github.com/pytorch/functorch) to run. Since PyTorch 1.13,
# functorch is included with PyTorch, so no separate installation is needed;
# for older versions, please find installation instructions at the link above.
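
######################################################################
# To illustrate composing ``jvp`` with ``vmap``, note that feeding each
# standard basis vector ``e_i`` in as the tangent yields ``J @ e_i``, the
# ``i``-th column of the Jacobian. ``vmap`` computes all of these columns
# in one batched call. The sketch below assumes PyTorch 2.0 or later and
# uses the ``torch.func`` namespace, which exposes the same transforms as
# functorch. The ``vm_*`` names are illustrative. The result is compared
# with ``torch.func.jacfwd``.

import torch
from torch.func import jacfwd, jvp as func_jvp, vmap

# Illustrative inputs
vm_x = torch.randn(3)
vm_y = torch.randn(3)

def vm_fn(x):
    # ``vm_y`` is closed over, so only ``x`` receives a tangent
    return x ** 2 + vm_y ** 2

def jvp_along(v):
    # func_jvp returns (vm_fn(vm_x), J @ v); keep only the JVP
    return func_jvp(vm_fn, (vm_x,), (v,))[1]

# Each row of the identity matrix is a standard basis vector e_i
vm_basis = torch.eye(3)
# vmap maps jvp_along over the rows, so row i of its output is J @ e_i
# (column i of J); transposing recovers J
vm_jac = vmap(jvp_along)(vm_basis).T

# jacfwd also builds the Jacobian with forward-mode AD
assert torch.allclose(vm_jac, jacfwd(vm_fn)(vm_x))
# For x ** 2 + y ** 2 the Jacobian with respect to x is diag(2 * x)
assert torch.allclose(vm_jac, torch.diag(2 * vm_x))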

import functorch as ft

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# Here is a basic example to compute the JVP of the above function.
# ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with
# a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# ``functorch.jvp`` requires every primal to be associated with a tangent.
# If we only want to associate certain inputs to ``fn`` with tangents,
# then we'll need to create a new function that captures inputs without tangents:
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

import functools
new_fn = functools.partial(fn, y=y)
primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
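
######################################################################
# Because the JVP is linear in the tangents, a JVP with tangents on both
# inputs equals the sum of the per-input JVPs. A zero tangent therefore
# switches off an input's contribution, which gives another way (besides
# ``functools.partial``) to leave an input out, at the cost of propagating
# zeros. For ``x ** 2 + y ** 2`` the combined JVP is
# ``2 * x * t_x + 2 * y * t_y``. The sketch below checks this with
# ``torch.func.jvp``, assumed here as the PyTorch 2.0+ counterpart of
# ``functorch.jvp``. The ``sq_sum``, ``p_*`` and ``t_*`` names are
# illustrative.

import torch
from torch.func import jvp as func_jvp

def sq_sum(x, y):
    return x ** 2 + y ** 2

# Illustrative primals and tangents
p_x, p_y = torch.randn(10, 10, dtype=torch.double), torch.randn(10, 10, dtype=torch.double)
t_x, t_y = torch.randn(10, 10, dtype=torch.double), torch.randn(10, 10, dtype=torch.double)

# Tangents on both inputs
_, jvp_both = func_jvp(sq_sum, (p_x, p_y), (t_x, t_y))
# A zero tangent switches off one input's contribution
_, jvp_x_only = func_jvp(sq_sum, (p_x, p_y), (t_x, torch.zeros_like(t_y)))
_, jvp_y_only = func_jvp(sq_sum, (p_x, p_y), (torch.zeros_like(t_x), t_y))

# Linearity: the combined JVP is the sum of the per-input JVPs
assert torch.allclose(jvp_both, jvp_x_only + jvp_y_only)
# Analytic value: 2 * x * t_x + 2 * y * t_y
assert torch.allclose(jvp_both, 2 * p_x * t_x + 2 * p_y * t_y)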

######################################################################
# Using the functional API with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with ``functorch.jvp`` to compute Jacobian-vector products
# with respect to the model parameters, we need to reformulate the
# ``nn.Module`` as a function that accepts both the model parameters and inputs
# to the module.

model = nn.Linear(5, 5)
input = torch.randn(16, 5)
tangents = tuple(torch.rand_like(p) for p in model.parameters())

# Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the state
# (``params`` and ``buffers``) and returns a functional version of the model,
# ``func``, that can be invoked like ``func(params, buffers, input)``.
# ``ft.make_functional_with_buffers`` is analogous to the stateless ``nn.Module`` API
# that you saw previously, and we're working on consolidating the two.
func, params, buffers = ft.make_functional_with_buffers(model)

# Because ``jvp`` requires every input to be associated with a tangent, we need to
# create a new function that, when given the parameters, produces the output.
def func_params_only(params):
    return func(params, buffers, input)

model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))
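
######################################################################
# In PyTorch 2.0 and later, ``functorch.make_functional_with_buffers`` is
# deprecated in favor of ``torch.func.functional_call``. A minimal sketch of
# the same parameter JVP under that assumption passes a dictionary of
# parameters to ``functional_call`` inside ``torch.func.jvp``. The ``mf_*``
# names are illustrative. ``nn.Linear`` has no buffers; a module that does
# would also need its buffers supplied to ``functional_call``.

import torch
import torch.nn as nn
from torch.func import functional_call, jvp as func_jvp

# Illustrative module and input
mf_model = nn.Linear(5, 5, dtype=torch.double)
mf_input = torch.randn(16, 5, dtype=torch.double)
# Detached copies of the parameters act as the primals
mf_params = {name: p.detach() for name, p in mf_model.named_parameters()}
mf_tangents = {name: torch.rand_like(p) for name, p in mf_params.items()}

def mf_params_only(params):
    # Run the module with ``params`` substituted for its own parameters
    return functional_call(mf_model, params, (mf_input,))

# The dict of primals and the dict of tangents must have matching keys
mf_out, mf_jvp = func_jvp(mf_params_only, (mf_params,), (mf_tangents,))

# Analytic check: the parameter JVP of a linear layer is input @ dW.T + db
assert torch.allclose(mf_jvp, mf_input @ mf_tangents["weight"].T + mf_tangents["bias"])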


######################################################################
# [0] https://en.wikipedia.org/wiki/Dual_number