"""1Explicit horizontal fusion with foreach_map and torch.compile2===============================================================34**Author:** `Michael Lazos <https://github.com/mlazos>`_5"""67#########################################################8# Horizontal fusion is a key optimization in ML compilers. In eager,9# this is typically expressed using the torch._foreach* ops which parallelizes10# operations across a list of tensors. However, supporting all possible permutations11# of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map12# allows conversion of any pointwise op in ``torch`` to a horiztonally fused foreach13# variant. In this tutorial, we will demonstrate how to implement the Adam optimizer14# with ``foreach_map`` to generate a fully fused kernel.15#16# .. note::17#18# This recipe describes a prototype feature. Prototype features are typically19# at an early stage for feedback and testing and are subject to change.20#21# Prerequisites22# -------------23#24# * PyTorch v2.7.0 or later25#2627#####################################################################28# Model Setup29# ~~~~~~~~~~~~~~~~~~~~~30# For this example, we'll use a simple sequence of linear layers.31# We instantiate an independent copy to compare the two optimizer implementations.32#33import torch3435# exit cleanly if we are on a device that doesn't support ``torch.compile``36if torch.cuda.get_device_capability() < (7, 0):37print("Exiting because torch.compile is not supported on this device.")38import sys39sys.exit(0)4041# Create simple model42model = torch.nn.Sequential(43*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]44)45model_copy = torch.nn.Sequential(46*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]47)48input = torch.rand(1024, device="cuda")4950# run forward pass51output = model(input)52output_copy = model_copy(input)5354# run backward to populate the grads for our optimizer below55output.sum().backward()56output_copy.sum().backward()5758#####################################################################59# Helper functions for foreach_map implementation60# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~61#62# In this section, we'll begin our implementation of the Adam optimizer.63#64from torch._higher_order_ops.foreach_map import foreach_map6566# Helper function to extract optimizer states from a torch.optim.Adam instance67def get_inputs(optim):68steps = []69params = []70grads = []71exp_avgs = []72exp_avg_sqs = []73for group in optim.param_groups:74for p in group["params"]:75params.append(p)76grads.append(p.grad)77state = optim.state[p]78exp_avgs.append(state["exp_avg"])79exp_avg_sqs.append(state["exp_avg_sq"])80steps.append(state["step"])8182return steps, params, exp_avgs, exp_avg_sqs838485# Functions to update the different optimizer states86def update_exp_avg_sq(exp_avg_sq, grad, beta2):87return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)8889def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):90bias_correction1 = 1 - torch.pow(beta1, step)91bias_correction2 = (1 - torch.pow(beta2, step)).sqrt()92step_size = (lr / bias_correction1).neg()93denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)94return torch.add(param, torch.div(exp_avg, denom))9596# Our full Adam implementation97def foreach_map_adam(98steps,99params,100exp_avgs,101exp_avg_sqs,102weight_decay=0,103beta1=0.9,104beta2=0.999,105lr=1e-3,106eps=1e-8,107):108with 
torch.no_grad():109grads = [param.grad for param in params]110# update step111updated_steps = foreach_map(lambda x: x + 1, steps)112torch._foreach_copy_(steps, updated_steps)113114if weight_decay != 0:115foreach_map(torch.add, (grads,), alpha=weight_decay)116117# Higher-order operators (HOPs) cannot have multiple outputs at the moment118# need to call foreach_map once for each output119exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1)120exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2)121params_updated = foreach_map(122update_param,123params,124steps,125exp_avgs_updated,126exp_avgs_sq_updated,127beta1,128beta2,129lr,130eps,131)132# Higher-order operators (HOPs) don't support input mutation today133# so manually update the states in-place134torch._foreach_copy_(exp_avgs, exp_avgs_updated)135torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated)136torch._foreach_copy_(params, params_updated)137return138139#####################################################################140# Setting up and running the compiled kernel141# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~142#143# In this section, we'll run our Adam optimizer144# and compare the results145#146# .. note::147#148# ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.149opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))150opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))151152# warm up the optimizer state dict153opt_eager.step()154opt_eager_copy.step()155156inputs = get_inputs(opt_eager_copy)157compiled_adam = torch.compile(foreach_map_adam)158159# optionally view the output code160torch._logging.set_logs(output_code=True)161162# Warmup runs to compile the function163for _ in range(5):164opt_eager.step()165compiled_adam(*inputs)166167for eager_p, compile_p in zip(opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]):168torch.allclose(eager_p, compile_p)169170# Benchmark performance171172# Let's define a helpful benchmarking function:173import torch.utils.benchmark as benchmark174175def benchmark_torch_function_in_microseconds(f, *args, **kwargs):176t0 = benchmark.Timer(177stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}178)179return t0.blocked_autorange().mean * 1e6180181eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)182compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))183184assert eager_runtime > compiled_runtime185186print(f"eager runtime: {eager_runtime}us")187print(f"compiled runtime: {compiled_runtime}us")188189190191######################################################################192# Conclusion193# ~~~~~~~~~~194# In this tutorial, we successfully implemented a custom fully-fused Adam optimizer using foreach_map.195# By leveraging the power of foreach_map and torch.compile, we were able to create an optimized version of the Adam196# optimizer that can be used in various machine learning applications. 

#####################################################################
# Setting up and running the compiled kernel
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In this section, we'll run our Adam optimizer
# and compare the results against the eager implementation.
#
# .. note::
#
#    ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.
opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01))

# Warm up the optimizer state dict
opt_eager.step()
opt_eager_copy.step()

inputs = get_inputs(opt_eager_copy)
compiled_adam = torch.compile(foreach_map_adam)

# Optionally view the generated output code
torch._logging.set_logs(output_code=True)

# Warmup runs to compile the function
for _ in range(5):
    opt_eager.step()
    compiled_adam(*inputs)

# Compare the parameters produced by the eager optimizer and the compiled kernel
for eager_p, compile_p in zip(
    opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]
):
    torch.allclose(eager_p, compile_p)

# Benchmark performance

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))

assert eager_runtime > compiled_runtime

print(f"eager runtime: {eager_runtime}us")
print(f"compiled runtime: {compiled_runtime}us")


######################################################################
# Conclusion
# ~~~~~~~~~~
# In this tutorial, we implemented a custom fully-fused Adam optimizer using ``foreach_map``
# and ``torch.compile``. By converting each pointwise update into a horizontally fused kernel,
# the compiled optimizer outperforms the eager implementation, as the benchmark above shows.
# The same approach applies to any pointwise computation over lists of tensors that can
# benefit from horizontal fusion.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an introduction to the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.