"""1(beta) Running the compiled optimizer with an LR Scheduler2============================================================34**Author:** `Michael Lazos <https://github.com/mlazos>`_5"""67#########################################################8# The optimizer is a key algorithm for training any deep learning model.9# In this example, we will show how to pair the optimizer, which has been compiled using ``torch.compile``,10# with the LR schedulers to accelerate training convergence.11#12# .. note::13#14# This tutorial requires PyTorch 2.3.0 or later.1516#####################################################################17# Model Setup18# ~~~~~~~~~~~~~~~~~~~~~19# For this example, we'll use a simple sequence of linear layers.20#2122import torch2324# Create simple model25model = torch.nn.Sequential(26*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]27)28input = torch.rand(1024, device="cuda")2930# run forward pass31output = model(input)3233# run backward to populate the grads for our optimizer below34output.sum().backward()353637#####################################################################38# Setting up and running the compiled optimizer with LR Scheduler39# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~40#41# In this section, we'll use the Adam optimizer with LinearLR Scheduler42# and create a helper function to wrap the ``step()`` call for each of them43# in ``torch.compile()``.44#45# .. note::46#47# ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.484950# exit cleanly if we are on a device that doesn't support ``torch.compile``51if torch.cuda.get_device_capability() < (7, 0):52print("Exiting because torch.compile is not supported on this device.")53import sys54sys.exit(0)5556# !!! IMPORTANT !!! 

######################################################################
# Extension: What happens with a non-tensor LR?
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For the curious, we will show how to peek into what happens with ``torch.compile``
# when we don't wrap the LR in a tensor.

# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)


@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)

# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
    fn()


######################################################################
# With this example, we can see that we recompile the optimizer a few times
# due to the guard failure on the ``lr`` in ``param_groups[0]``.

######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we showed how to pair an optimizer compiled with ``torch.compile``
# with an LR scheduler to accelerate training convergence. We used a model consisting
# of a simple sequence of linear layers with the Adam optimizer paired
# with a LinearLR scheduler to demonstrate the LR changing across iterations.
#
# See also:
#
# * `Compiled optimizer tutorial <https://pytorch.org/tutorials/recipes/compiling_optimizer.html>`__ - an intro to the compiled optimizer.
# * `Compiling the optimizer with PT2 <https://dev-discuss.pytorch.org/t/compiling-the-optimizer-with-pt2/1669>`__ - deeper technical details on the compiled optimizer.
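
######################################################################
# Extension: Verifying that the tensor LR avoids recompiles
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# As a final sketch, we repeat the tensor-LR setup from the first section with
# recompile logging still enabled. Beyond the initial compilation, no recompile
# messages should appear, because ``torch.compile`` guards on the LR tensor itself
# rather than on its changing value.

opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)


@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()


for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])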