"""
Explicit horizontal fusion with foreach_map and torch.compile
===============================================================

**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager mode,
# it is typically expressed using the ``torch._foreach*`` ops, which parallelize
# an operation across a list of tensors. However, supporting all possible permutations
# of arguments (for example, mixtures of scalars and lists) is quite difficult.
# ``foreach_map`` allows any pointwise op in ``torch`` to be converted into a
# horizontally fused foreach variant. In this tutorial, we will demonstrate how to
# implement the Adam optimizer with ``foreach_map`` to generate a fully fused kernel.
#
# .. note::
#
#    This recipe describes a prototype feature. Prototype features are typically
#    at an early stage for feedback and testing and are subject to change.
#
# Prerequisites
# -------------
#
# * PyTorch v2.7.0 or later
#
#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
# For this example, we'll use a simple sequence of linear layers.
# We instantiate an independent copy with identical weights so that the two
# optimizer implementations can be compared from the same starting point.
#
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

# Create a simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
model_copy = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
# give the copy the same weights so that the eager and foreach_map optimizers
# start from an identical state
model_copy.load_state_dict(model.state_dict())
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)
output_copy = model_copy(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()
output_copy.sum().backward()

#####################################################################
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map

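#####################################################################
# Before building the optimizer, here is a minimal sketch (not part of the original
# recipe) of what ``foreach_map`` does with an ordinary pointwise op. ``scale_all`` is an
# illustrative helper: it multiplies every tensor in a list by a scalar, which is the
# horizontally fused equivalent of the hand-written ``torch._foreach_mul`` op.
@torch.compile
def scale_all(tensor_list, scalar):
    # foreach_map applies torch.mul across the whole list of tensors
    return foreach_map(torch.mul, tensor_list, scalar)

example_tensors = [torch.ones(4, device="cuda") for _ in range(3)]
fused_result = scale_all(example_tensors, 2.0)
reference_result = torch._foreach_mul(example_tensors, 2.0)
assert all(torch.allclose(a, b) for a, b in zip(fused_result, reference_result))
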
# Helper function to extract optimizer states from a torch.optim.Adam instance
def get_inputs(optim):
    steps = []
    params = []
    exp_avgs = []
    exp_avg_sqs = []
    for group in optim.param_groups:
        for p in group["params"]:
            params.append(p)
            state = optim.state[p]
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            steps.append(state["step"])

    return steps, params, exp_avgs, exp_avg_sqs


# Functions to update the different optimizer states
def update_exp_avg_sq(exp_avg_sq, grad, beta2):
    return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)

def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    bias_correction1 = 1 - torch.pow(beta1, step)
    bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()
    step_size = (lr / bias_correction1).neg()
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))

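#####################################################################
# A note on the algebra in ``update_param``: with ``step_size = -(lr / bias_correction1)``,
# the returned value ``param + exp_avg / denom`` expands to
#
#   param - (lr / bias_correction1) * exp_avg / (exp_avg_sq.sqrt() / bias_correction2 + eps)
#
# which is the standard Adam parameter update (``bias_correction2`` already holds
# ``sqrt(1 - beta2**step)``). Folding the step size into the denominator keeps the update a
# single out-of-place pointwise expression that ``foreach_map`` can fuse. The quick,
# illustrative check below (not part of the original recipe) confirms the equivalence on
# random tensors.
_p, _m, _v = (torch.rand(8, device="cuda") for _ in range(3))
_step = torch.tensor(3.0)
_lr, _b1, _b2, _eps = 1e-3, 0.9, 0.999, 1e-8
_textbook = _p - _lr / (1 - _b1**3) * _m / (_v.sqrt() / (1 - _b2**3) ** 0.5 + _eps)
assert torch.allclose(update_param(_p, _step, _m, _v, _b1, _b2, _lr, _eps), _textbook)
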
# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
    exp_avgs,
    exp_avg_sqs,
    weight_decay=0,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    eps=1e-8,
):
    with torch.no_grad():
        grads = [param.grad for param in params]
        # update step
        updated_steps = foreach_map(lambda x: x + 1, steps)
        torch._foreach_copy_(steps, updated_steps)

        if weight_decay != 0:
            # apply (non-decoupled) weight decay: grad = grad + weight_decay * param
            grads = foreach_map(lambda g, p: g + weight_decay * p, grads, params)

        # Higher-order operators (HOPs) cannot have multiple outputs at the moment,
        # so we need to call foreach_map once for each output
        exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)
        exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)
        params_updated = foreach_map(
            update_param,
            params,
            steps,
            exp_avgs_updated,
            exp_avgs_sq_updated,
            beta1,
            beta2,
            lr,
            eps,
        )
        # Higher-order operators (HOPs) don't support input mutation today,
        # so manually update the states in-place
        torch._foreach_copy_(exp_avgs, exp_avgs_updated)
        torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)
        torch._foreach_copy_(params, params_updated)
    return

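#####################################################################
# For comparison, here is a rough, simplified sketch (not part of the original recipe) of
# what an Adam-style step looks like when written by hand with the existing
# ``torch._foreach_*`` ops. ``handwritten_foreach_step`` is an illustrative helper that
# omits the step counter and bias correction to stay short; it is defined here but never
# called. Each ``_foreach_*`` call launches its own horizontally fused kernel, whereas
# ``foreach_map`` plus ``torch.compile`` can fuse the entire sequence above, including the
# user-defined ``update_param`` math, into a single kernel.
def handwritten_foreach_step(
    params, exp_avgs, exp_avg_sqs, beta1=0.9, beta2=0.999, lr=1e-3, eps=1e-8
):
    grads = [p.grad for p in params]
    # first and second moment updates, one fused kernel per call
    torch._foreach_lerp_(exp_avgs, grads, 1 - beta1)
    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
    # parameter update (bias correction omitted for brevity)
    denoms = torch._foreach_sqrt(exp_avg_sqs)
    torch._foreach_add_(denoms, eps)
    torch._foreach_addcdiv_(params, exp_avgs, denoms, -lr)
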
#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our Adam optimizer
# and compare the results against the eager implementation.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))

# warm up the optimizer state dict
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# optionally view the output code
torch._logging.set_logs(output_code=True)

# Warmup runs to compile the function
# (pass lr=0.01 so the compiled implementation matches the eager optimizer's learning rate)
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs, lr=0.01)

# check that the eager and compiled implementations stay in agreement
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    assert torch.allclose(eager_p, compile_p)

#####################################################################
# Benchmark performance
# ~~~~~~~~~~~~~~~~~~~~~
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs, lr=0.01))

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")

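#####################################################################
# As a final illustration (not part of the original recipe), here is a minimal sketch of how
# the compiled ``foreach_map`` optimizer could drive a short training loop for ``model_copy``.
# The loss and iteration count are arbitrary; ``inputs`` still refers to the optimizer states
# extracted from ``opt_eager_copy`` above.
for _ in range(3):
    model_copy.zero_grad()
    loss = model_copy(input).sum()
    loss.backward()
    compiled_adam(*inputs, lr=0.01)
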
######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom fully-fused Adam optimizer using ``foreach_map``.
# By combining ``foreach_map`` with ``torch.compile``, we generated a horizontally fused kernel
# for the entire optimizer step and verified that it matches the eager implementation while
# running faster. The same approach can be used to build fused variants of other pointwise
# computations over lists of tensors.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro into the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.