"""
torch.vmap
==========
This tutorial introduces torch.vmap, an autovectorizer for PyTorch operations.
torch.vmap is a prototype feature and cannot handle a number of use cases;
however, we would like to gather use cases for it to inform the design. If you
are considering using torch.vmap or think it would be really cool for something,
please contact us at https://github.com/pytorch/pytorch/issues/42368.

So, what is vmap?
-----------------
vmap is a higher-order function. It accepts a function `func` and returns a new
function that maps `func` over some dimension of the inputs. It is highly
inspired by JAX's vmap.

Semantically, vmap pushes the "map" into PyTorch operations called by `func`,
effectively vectorizing those operations.
"""
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
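
# A quick illustration (not from the original recipe) of the "pushed-in map"
# described above: for a simple pointwise op, vmap(op) over a batch should
# agree with mapping op over each row and stacking the results, but it runs
# as a single vectorized call.
xs = torch.randn(3, 4)
assert torch.allclose(vmap(torch.sin)(xs),
                      torch.stack([torch.sin(x) for x in xs.unbind()]))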
####################################################################
# The first use case for vmap is making it easier to handle
# batch dimensions in your code. One can write a function `func`
# that runs on examples and then lift it to a function that can
# take batches of examples with `vmap(func)`. `func` however
# is subject to many restrictions:
#
# - it must be functional (one cannot mutate a Python data structure
#   inside of it), with the exception of in-place PyTorch operations.
# - batches of examples must be provided as Tensors. This means that
#   vmap doesn't handle variable-length sequences out of the box.
#
# One example of using `vmap` is to compute batched dot products. PyTorch
# doesn't provide a batched `torch.dot` API; instead of unsuccessfully
# rummaging through docs, use `vmap` to construct a new function:

torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
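
# Sanity check (not part of the original recipe): the vmapped function should
# agree with computing each dot product in a plain Python loop and stacking.
expected_dots = torch.stack([torch.dot(a, b)
                             for a, b in zip(x.unbind(), y.unbind())])
assert torch.allclose(batched_dot(x, y), expected_dots)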
####################################################################
# `vmap` can be helpful in hiding batch dimensions, leading to a simpler
# model authoring experience.
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

# Note that `model` doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigans), we can simply call `vmap`. `vmap` batches over ALL
# inputs unless otherwise specified (with the in_dims argument; see the
# sketch after this example and the documentation for more details).
def model(feature_vec):
    # Very simple linear model with activation
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
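
####################################################################
# A sketch of the in_dims argument mentioned above (assuming the prototype
# follows the JAX-style convention where `None` means "do not map over this
# input"). `model_with_weights` is a hypothetical variant of `model` that
# takes the weights explicitly; only the examples are mapped over dim 0,
# while the weights are shared across the batch.
def model_with_weights(feature_vec, w):
    return feature_vec.dot(w).relu()

result_shared_weights = torch.vmap(model_with_weights, in_dims=(0, None))(examples, weights)
assert torch.allclose(result_shared_weights, expected)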
####################################################################
# `vmap` can also help vectorize computations that were previously difficult
# or impossible to batch. This brings us to our second use case: batched
# gradient computation.
#
# - https://github.com/pytorch/pytorch/issues/8304
# - https://github.com/pytorch/pytorch/issues/23475
#
# The PyTorch autograd engine computes vjps (vector-Jacobian products).
# Using vmap, we can compute (batched vector)-Jacobian products.
#
# One example of this is computing a full Jacobian matrix (this can also be
# applied to computing a full Hessian matrix; a sketch of that follows the
# Jacobian example below).
# Computing a full Jacobian matrix for some function f: R^N -> R^N usually
# requires N calls to `autograd.grad`, one per Jacobian row.

# Setup
N = 5
def f(x):
    return x ** 2

x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)

# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
                 for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)

# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
    return torch.autograd.grad(y, x, v)[0]

jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
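
####################################################################
# The same idea extends to Hessians, as noted above. This is a sketch (not
# part of the original recipe), assuming the prototype's batching also covers
# this second-order case: the Hessian of a scalar function `g` is the Jacobian
# of its gradient, so we build the gradient with `create_graph=True` and then
# vmap vjps of that gradient against the basis vectors.
def g(x):
    return (x ** 3).sum()

x2 = torch.randn(N, requires_grad=True)
(grad_g,) = torch.autograd.grad(g(x2), x2, create_graph=True)

def get_hessian_vjp(v):
    return torch.autograd.grad(grad_g, x2, v, retain_graph=True)[0]

hessian_vmap = vmap(get_hessian_vjp)(basis_vectors)
hessian_rows = torch.stack([get_hessian_vjp(v) for v in basis_vectors.unbind()])
assert torch.allclose(hessian_vmap, hessian_rows)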
####################################################################
# The third main use case for vmap is computing per-sample-gradients.
# This is something that the vmap prototype cannot handle performantly
# right now. We're not sure what the API for computing per-sample-gradients
# should be, but if you have ideas, please comment in
# https://github.com/pytorch/pytorch/issues/7786.

def model(sample, weight):
    # do something...
    return torch.dot(sample, weight)

weight = torch.randn(5)  # a toy parameter for the model above

def grad_sample(sample):
    # Gradient of the per-sample output with respect to `weight`.
    return torch.autograd.functional.vjp(lambda weight: model(sample, weight), weight)[1]

# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.

# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
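
####################################################################
# For comparison, a sketch (not from the original recipe) of computing
# per-sample gradients today without vmap: loop over the batch and call
# `grad_sample` once per sample, then stack. This works, but gives up the
# vectorization that the hypothetical `vmap(grad_sample)` call above aims for.
batch_of_samples = torch.randn(64, 5)
per_sample_grads = torch.stack([grad_sample(sample)
                                for sample in batch_of_samples.unbind()])
assert per_sample_grads.shape == (64, 5)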