Path: blob/main/prototype_source/vmap_recipe.py
"""1torch.vmap2==========3This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations.4torch.vmap is a prototype feature and cannot handle a number of use cases;5however, we would like to gather use cases for it to inform the design. If you6are considering using torch.vmap or think it would be really cool for something,7please contact us at https://github.com/pytorch/pytorch/issues/42368.89So, what is vmap?10-----------------11vmap is a higher-order function. It accepts a function `func` and returns a new12function that maps `func` over some dimension of the inputs. It is highly13inspired by JAX's vmap.1415Semantically, vmap pushes the "map" into PyTorch operations called by `func`,16effectively vectorizing those operations.17"""18import torch19# NB: vmap is only available on nightly builds of PyTorch.20# You can download one at pytorch.org if you're interested in testing it out.21from torch import vmap2223####################################################################24# The first use case for vmap is making it easier to handle25# batch dimensions in your code. One can write a function `func`26# that runs on examples and then lift it to a function that can27# take batches of examples with `vmap(func)`. `func` however28# is subject to many restrictions:29#30# - it must be functional (one cannot mutate a Python data structure31# inside of it), with the exception of in-place PyTorch operations.32# - batches of examples must be provided as Tensors. This means that33# vmap doesn't handle variable-length sequences out of the box.34#35# One example of using `vmap` is to compute batched dot products. PyTorch36# doesn't provide a batched `torch.dot` API; instead of unsuccessfully37# rummaging through docs, use `vmap` to construct a new function:3839torch.dot # [D], [D] -> []40batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]41x, y = torch.randn(2, 5), torch.randn(2, 5)42batched_dot(x, y)4344####################################################################45# `vmap` can be helpful in hiding batch dimensions, leading to a simpler46# model authoring experience.47batch_size, feature_size = 3, 548weights = torch.randn(feature_size, requires_grad=True)4950# Note that model doesn't work with a batch of feature vectors because51# torch.dot must take 1D tensors. It's pretty easy to rewrite this52# to use `torch.matmul` instead, but if we didn't want to do that or if53# the code is more complicated (e.g., does some advanced indexing54# shenanigins), we can simply call `vmap`. `vmap` batches over ALL55# inputs, unless otherwise specified (with the in_dims argument,56# please see the documentation for more details).57def model(feature_vec):58# Very simple linear model with activation59return feature_vec.dot(weights).relu()6061examples = torch.randn(batch_size, feature_size)62result = torch.vmap(model)(examples)63expected = torch.stack([model(example) for example in examples.unbind()])64assert torch.allclose(result, expected)6566####################################################################67# `vmap` can also help vectorize computations that were previously difficult68# or impossible to batch. 
####################################################################
# `vmap` can also help vectorize computations that were previously difficult
# or impossible to batch. This brings us to our second use case: batched
# gradient computation.
#
# - https://github.com/pytorch/pytorch/issues/8304
# - https://github.com/pytorch/pytorch/issues/23475
#
# The PyTorch autograd engine computes vjps (vector-Jacobian products).
# Using vmap, we can compute (batched vector)-Jacobian products.
#
# One example of this is computing a full Jacobian matrix (this can also be
# applied to computing a full Hessian matrix).
# Computing a full Jacobian matrix for some function f: R^N -> R^N usually
# requires N calls to `autograd.grad`, one per Jacobian row.

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach: one call to `autograd.grad` per Jacobian row
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)

####################################################################
# The third main use case for vmap is computing per-sample-gradients.
# This is something that the vmap prototype cannot handle performantly
# right now. We're not sure what the API for computing per-sample-gradients
# should be, but if you have ideas, please comment in
# https://github.com/pytorch/pytorch/issues/7786.

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

# A single weight vector shared across samples
weight = torch.randn(feature_size, requires_grad=True)

def grad_sample(sample):
    # Gradient of the (scalar) model output with respect to `weight`,
    # for a single sample
    return torch.autograd.functional.vjp(lambda weight: model(sample, weight), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
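####################################################################
# For comparison, a sketch of the sequential fallback that such a batched
# API would replace (the names `batch_of_samples` and `per_sample_grads`
# are ours, not a PyTorch API): compute one gradient per sample in a
# Python loop, then stack the results into a [batch, feature] tensor.
batch_of_samples = torch.randn(64, feature_size)
per_sample_grads = torch.stack([grad_sample(sample)
                                for sample in batch_of_samples.unbind()])
# For this linear model, the gradient w.r.t. `weight` is the sample itself.
assert torch.allclose(per_sample_grads, batch_of_samples)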