# -*- coding: utf-8 -*-
"""
Forward-mode Automatic Differentiation (Beta)
=============================================

This tutorial demonstrates how to use forward-mode AD to compute
directional derivatives (or equivalently, Jacobian-vector products).

The tutorial below uses some APIs only available in PyTorch versions >= 1.11
(or nightly builds).

Also note that forward-mode AD is currently in beta. The API is
subject to change and operator coverage is still incomplete.

Basic Usage
--------------------------------------------------------------------
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly
alongside the forward pass. We can use forward-mode AD to compute a
directional derivative by performing the forward pass as before,
except we first associate our input with another tensor representing
the direction of the directional derivative (or equivalently, the ``v``
in a Jacobian-vector product). When an input, which we call "primal", is
associated with a "direction" tensor, which we call "tangent", the
resultant new tensor object is called a "dual tensor" for its connection
to dual numbers[0].

As the forward pass is performed, if any input tensors are dual tensors,
extra computation is performed to propagate this "sensitivity" of the
function.

"""

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed within a ``dual_level``
# context. All dual tensors created in such a context will have their
# tangents destroyed upon exit. This is to ensure that if the output or
# intermediate results of this computation are reused in a future forward
# AD computation, their tangents (which are associated with this
# computation) won't be confused with tangents from the later computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal.
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes.
    jvp = fwAD.unpack_dual(dual_output).tangent

# Outside the ``dual_level`` context, the tangents of dual tensors created
# inside it have been destroyed.
assert fwAD.unpack_dual(dual_output).tangent is None
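
# As a quick sanity check, we can compare the computed JVP against the
# analytic directional derivative. For ``fn(x, y) = x ** 2 + y ** 2``, only
# ``x`` carries a tangent here (``plain_tensor`` has an implicit zero
# tangent), so the JVP should equal ``2 * primal * tangent`` elementwise.
expected_jvp = 2 * primal * tangent
assert torch.allclose(jvp, expected_jvp)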

######################################################################
# Usage with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with forward AD, replace the parameters of your
# model with dual tensors before performing the forward pass. At the
# time of writing, it is not possible to create dual tensor
# ``nn.Parameter`` objects. As a workaround, one must register the dual
# tensor as a non-parameter attribute of the module.

import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

######################################################################
# Using the functional Module API (beta)
# --------------------------------------------------------------------
# Another way to use ``nn.Module`` with forward AD is to utilize
# the functional Module API (also known as the stateless Module API).

from torch.func import functional_call

# We need a fresh module because the functional call requires the
# model to have its parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check that both approaches compute the same Jacobian-vector product.
assert torch.allclose(jvp, jvp2)
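
# As another sanity check, the Jacobian-vector product of a linear layer with
# respect to its parameters has a closed form: with the input held fixed,
# perturbing ``weight`` by ``t_w`` and ``bias`` by ``t_b`` changes the output
# by ``input @ t_w.T + t_b``. The check below is a minimal sketch that simply
# reuses the ``tangents`` dictionary from above.
expected_jvp = input @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp2, expected_jvp)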

######################################################################
# Custom autograd Function
# --------------------------------------------------------------------
# Custom Functions also support forward-mode AD. To create a custom Function
# supporting forward-mode AD, register the ``jvp()`` static method. It is
# possible, but not mandatory, for custom Functions to support both forward
# and backward AD. See the
# `documentation <https://pytorch.org/docs/master/notes/extending.html#forward-mode-ad>`_
# for more information.

class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ``ctx`` can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        # If the tensor stored in ``ctx`` will not also be used in the backward
        # pass, one can manually free it using ``del``.
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# ``gradcheck`` only checks backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell
# ``gradcheck`` to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)

######################################################################
# Functional API (beta)
# --------------------------------------------------------------------
# We also offer a higher-level functional API in functorch
# for computing Jacobian-vector products that you may find simpler to use
# depending on your use case.
#
# The benefit of the functional API is that there isn't a need to understand
# or use the lower-level dual tensor API and that you can compose it with
# other `functorch transforms (like vmap) <https://pytorch.org/functorch/stable/notebooks/jacobians_hessians.html>`_
# (see the short sketch at the end of this section); the downside is that it
# offers you less control.
#
# Note that the remainder of this tutorial will require functorch
# (https://github.com/pytorch/functorch) to run. Please find installation
# instructions at the specified link.

import functorch as ft

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# Here is a basic example to compute the JVP of the above function.
# ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with
# a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# ``functorch.jvp`` requires every primal to be associated with a tangent.
# If we only want to associate certain inputs to ``fn`` with tangents,
# then we'll need to create a new function that captures inputs without tangents:
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

import functools
new_fn = functools.partial(fn, y=y)
primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
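
# The composability mentioned above can be illustrated with a minimal sketch:
# assuming we want the directional derivative of ``new_fn`` along several
# directions at once, we can ``vmap`` over the tangent argument instead of
# looping in Python. The helper name ``jvp_along`` and the batch size of 3
# below are illustrative choices, not part of the functorch API.
batched_tangents = torch.randn(3, 10, 10)

def jvp_along(t):
    # Return only the tangent output of ``jvp`` for a single direction ``t``.
    return ft.jvp(new_fn, (primal,), (t,))[1]

batched_jvps = ft.vmap(jvp_along)(batched_tangents)
assert batched_jvps.shape == (3, 10, 10)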

######################################################################
# Using the functional API with Modules
# --------------------------------------------------------------------
# To use ``nn.Module`` with ``functorch.jvp`` to compute Jacobian-vector products
# with respect to the model parameters, we need to reformulate the
# ``nn.Module`` as a function that accepts both the model parameters and inputs
# to the module.

model = nn.Linear(5, 5)
input = torch.randn(16, 5)
tangents = tuple([torch.rand_like(p) for p in model.parameters()])

# Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the
# state (``params`` and buffers) and returns a functional version of the model
# that can be invoked like a function.
# That is, the returned ``func`` can be invoked like
# ``func(params, buffers, input)``.
# ``ft.make_functional_with_buffers`` is analogous to the ``nn.Module`` stateless
# API that you saw previously; we're working on consolidating the two.
func, params, buffers = ft.make_functional_with_buffers(model)

# Because ``jvp`` requires every input to be associated with a tangent, we need
# to create a new function that, when given the parameters, produces the output.
def func_params_only(params):
    return func(params, buffers, input)

model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))


######################################################################
# [0] https://en.wikipedia.org/wiki/Dual_number