# -*- coding: utf-8 -*-
"""
Jacobians, Hessians, hvp, vhp, and more: composing function transforms
======================================================================

Computing Jacobians or Hessians is useful in a number of non-traditional
deep learning models. It is difficult (or annoying) to compute these quantities
efficiently using PyTorch's regular autodiff APIs
(``Tensor.backward()``, ``torch.autograd.grad``). PyTorch's
`JAX-inspired <https://github.com/google/jax>`_
`function transforms API <https://pytorch.org/docs/master/func.html>`_
provides ways of computing various higher-order autodiff quantities
efficiently.

.. note::

   This tutorial requires PyTorch 2.0.0 or later.

Computing the Jacobian
----------------------
"""

import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)

######################################################################
# Let's start with a function that we'd like to compute the Jacobian of.
# This is a simple linear function with a non-linear activation.

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

######################################################################
# Let's add some dummy data: a weight, a bias, and a feature vector ``x``.

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

######################################################################
# Let's think of ``predict`` as a function that maps the input ``x`` from
# :math:`R^D \to R^D`. PyTorch Autograd computes vector-Jacobian products. In
# order to compute the full Jacobian of this :math:`R^D \to R^D` function, we
# would have to compute it row-by-row by using a different unit vector each time.

def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row

######################################################################
# Instead of computing the Jacobian row-by-row, we can use PyTorch's
# ``torch.vmap`` function transform to get rid of the for-loop and vectorize the
# computation. We can't directly apply ``vmap`` to ``torch.autograd.grad``;
# instead, PyTorch provides a ``torch.func.vjp`` transform that composes with
# ``torch.vmap``:

from torch.func import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
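
######################################################################
# As an optional sanity check (this is just the calculus for this particular
# toy function, not part of the ``torch.func`` API): since ``predict`` is
# ``tanh`` of a linear layer, its Jacobian with respect to ``x`` has the closed
# form :math:`J_{ij} = (1 - y_i^2) W_{ij}` with :math:`y = \tanh(Wx + b)`, so we
# can compare against that directly.

y = predict(weight, bias, x)
analytic_jacobian = (1 - y ** 2).unsqueeze(-1) * weight  # entry [i, j] is (1 - y_i^2) * W[i, j]
assert torch.allclose(ft_jacobian, analytic_jacobian)
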
######################################################################
# In a later tutorial, a composition of reverse-mode AD and ``vmap`` will give us
# per-sample gradients.
# In this tutorial, composing reverse-mode AD and ``vmap`` gives us Jacobian
# computation!
# Various compositions of ``vmap`` and autodiff transforms can give us different
# interesting quantities.
#
# PyTorch provides ``torch.func.jacrev`` as a convenience function that performs
# the ``vmap-vjp`` composition to compute Jacobians. ``jacrev`` accepts an
# ``argnums`` argument that says which argument we would like to compute
# Jacobians with respect to.

from torch.func import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# confirm that it matches our hand-rolled version
assert torch.allclose(ft_jacobian, jacobian)

######################################################################
# Let's compare the performance of the two ways of computing the Jacobian.
# The function transform version is much faster (and becomes even faster the
# more outputs there are).
#
# In general, we expect that vectorization via ``vmap`` can help eliminate overhead
# and give better utilization of your hardware.
#
# ``vmap`` does this magic by pushing the outer loop down into the function's
# primitive operations in order to obtain better performance.
#
# Let's make a quick function to evaluate performance and deal with
# microsecond and millisecond measurements:

def get_perf(first, first_descriptor, second, second_descriptor):
    """Takes torch.utils.benchmark measurements and prints the relative delta of second vs. first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower - faster) / slower
    if gain < 0:
        gain *= -1
    final_gain = gain * 100
    print(f"Performance delta: {final_gain:.4f} percent improvement with {second_descriptor}")

######################################################################
# And then run the performance comparison:

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)

######################################################################
# Let's do a relative performance comparison of the above with our ``get_perf`` function:

get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")

######################################################################
# Furthermore, it's pretty easy to flip the problem around and say we want to
# compute Jacobians of the parameters of our model (weight, bias) instead of the input.

# note the change via the ``argnums`` parameter: 0 and 1 map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
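
######################################################################
# A quick check of what we got back: ``jacrev`` returns Jacobians whose shape is
# the output shape followed by the shape of the argument being differentiated,
# so here we expect ``(D, D, D)`` for the weight and ``(D, D)`` for the bias.

print(ft_jac_weight.shape)  # output (D,) x weight (D, D)
print(ft_jac_bias.shape)    # output (D,) x bias (D,)
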
######################################################################
# Reverse-mode Jacobian (``jacrev``) vs forward-mode Jacobian (``jacfwd``)
# ------------------------------------------------------------------------
#
# We offer two APIs to compute Jacobians: ``jacrev`` and ``jacfwd``:
#
# - ``jacrev`` uses reverse-mode AD. As you saw above, it is a composition of our
#   ``vjp`` and ``vmap`` transforms.
# - ``jacfwd`` uses forward-mode AD. It is implemented as a composition of our
#   ``jvp`` and ``vmap`` transforms.
#
# ``jacfwd`` and ``jacrev`` can be substituted for each other, but they have
# different performance characteristics.
#
# As a general rule of thumb, if you're computing the Jacobian of an
# :math:`R^N \to R^M` function and there are many more outputs than inputs
# (that is, :math:`M > N`), then ``jacfwd`` is preferred; otherwise use
# ``jacrev``. There are exceptions to this rule, but a non-rigorous argument
# for it is as follows:
#
# In reverse-mode AD, we compute the Jacobian row-by-row, while in
# forward-mode AD (which computes Jacobian-vector products), we compute it
# column-by-column. The Jacobian matrix has M rows and N columns, so if it
# is taller or wider one way, we may prefer the method that deals with fewer
# rows or columns.

from torch.func import jacrev, jacfwd
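
######################################################################
# To make the row-vs-column picture concrete, here is a small sketch (reusing
# the ``D``-dimensional ``weight``, ``bias``, ``x``, and ``unit_vectors`` from
# earlier) that builds the same Jacobian column-by-column with ``torch.func.jvp``,
# mirroring the row-by-row ``vjp`` loop above. The helper name
# ``compute_jac_columns`` is ours, not part of the API.

from torch.func import jvp

def compute_jac_columns(xp):
    jacobian_cols = [jvp(partial(predict, weight, bias), (xp,), (vec,))[1]
                     for vec in unit_vectors]
    return torch.stack(jacobian_cols, dim=1)  # each jvp gives one column, J @ e_j

forward_jacobian = compute_jac_columns(x)
assert torch.allclose(forward_jacobian, jacrev(predict, argnums=2)(weight, bias, x))
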
######################################################################
# First, let's benchmark with more outputs than inputs:

Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')

######################################################################
# and then do a relative benchmark:

get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")

#######################################################################
# and now the reverse - more inputs (N) than outputs (M):

Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')

#######################################################################
# and a relative performance comparison:

get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")

#######################################################################
# Hessian computation with torch.func.hessian
# --------------------------------------------
# We offer a convenience API to compute Hessians: ``torch.func.hessian``.
# Hessians are the Jacobian of the Jacobian (or the partial derivative of
# the partial derivative, aka second order).
#
# This suggests that one can just compose the Jacobian transforms to
# compute the Hessian.
# Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.
#
# Note: to boost performance, depending on your model, you may also want to
# use ``jacfwd(jacfwd(f))`` or ``jacrev(jacrev(f))`` instead to compute Hessians,
# leveraging the rule of thumb above regarding wider vs. taller matrices.

from torch.func import hessian

# let's reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)

#######################################################################
# Let's verify we have the same result regardless of using the hessian API or
# using ``jacfwd(jacfwd())``:

torch.allclose(hess_api, hess_fwdfwd)
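
#######################################################################
# As an extra sanity check on the "under the hood" claim above, composing
# ``jacfwd`` over ``jacrev`` ourselves should match the ``hessian`` API, and the
# result should have the output shape followed by the input shape twice,
# that is ``(Dout, Din, Din)``. The ``hess_fwdrev`` name is just for this check.

hess_fwdrev = jacfwd(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
assert torch.allclose(hess_api, hess_fwdrev)
print(hess_api.shape)  # torch.Size([32, 512, 512])
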
#######################################################################
# Batch Jacobian and Batch Hessian
# --------------------------------
# In the above examples we've been operating with a single feature vector.
# In some cases you might want to take the Jacobian of a batch of outputs
# with respect to a batch of inputs. That is, given a batch of inputs of
# shape ``(B, N)`` and a function that goes from :math:`R^N \to R^M`, we would like
# a Jacobian of shape ``(B, M, N)``.
#
# The easiest way to do this is to use ``vmap``:

batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)

#######################################################################
# If you have a function that goes from ``(B, N) -> (B, M)`` instead, and you are
# certain that each input produces an independent output, then it's also
# sometimes possible to do this without ``vmap`` by summing the outputs
# and then computing the Jacobian of that function:

def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)

#######################################################################
# If you instead have a function that goes from :math:`R^N \to R^M` but inputs that
# are batched, compose ``vmap`` with ``jacrev`` to compute batched Jacobians,
# as we did above.
#
# Finally, batch Hessians can be computed similarly. It's easiest to think
# about them by using ``vmap`` to batch over Hessian computation, but in some
# cases the sum trick also works.

compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
print(batch_hess.shape)

#######################################################################
# Computing Hessian-vector products
# ---------------------------------
# The naive way to compute a Hessian-vector product (hvp) is to materialize
# the full Hessian and perform a dot-product with a vector. We can do better:
# it turns out we don't need to materialize the full Hessian to do this. We'll
# go through two (of many) different strategies to compute Hessian-vector products:
#
# - composing reverse-mode AD with reverse-mode AD
# - composing reverse-mode AD with forward-mode AD
#
# Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode
# with reverse-mode) is generally the more memory-efficient way to compute an
# hvp because forward-mode AD doesn't need to construct an Autograd graph and
# save intermediates for the backward pass:

from torch.func import jvp, grad, vjp

def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

#######################################################################
# Here's some sample usage:

def f(x):
    return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))

#######################################################################
# If PyTorch forward-mode AD does not have coverage for your operations, then we
# can instead compose reverse-mode AD with reverse-mode AD:

def hvp_revrev(f, primals, tangents):
    _, vjp_fn = vjp(grad(f), *primals)
    return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])
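
#######################################################################
# As a final sanity check (specific to this toy ``f``): since ``f(x) = x.sin().sum()``,
# its Hessian is the diagonal matrix ``diag(-x.sin())``, so both strategies should
# agree with the closed-form product ``-x.sin() * tangent`` without us ever
# materializing the full Hessian. The ``expected_hvp`` name is ours.

expected_hvp = -x.sin() * tangent
assert torch.allclose(result, expected_hvp)
assert torch.allclose(result_hvp_revrev[0], expected_hvp)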