"""
Reducing torch.compile cold start compilation time with regional compilation
============================================================================

**Author:** `Animesh Jain <https://github.com/anijain2305>`_

As deep learning models get larger, the compilation time of these models also
increases. This extended compilation time can result in a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows an example of how to reduce the cold start compilation time by choosing to
compile a repeated region of the model instead of the entire model.

Prerequisites
-------------

* PyTorch 2.5 or later

Setup
-----
Before we begin, we need to install ``torch`` if it is not already
available.

.. code-block:: sh

   pip install torch

.. note::
   This feature is available starting with the 2.5 release. If you are using version 2.4,
   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
"""

from time import perf_counter

######################################################################
# Steps
# -----
#
# In this recipe, we will follow these steps:
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Understand the difference between full model and regional compilation.
# 4. Measure the compilation time of the full model and the regional compilation.
#
# First, let's import the necessary libraries:
#

import torch
import torch.nn as nn


##########################################################
# Next, let's define and initialize a neural network with repeated regions.
#
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, self.linear is outside the scope of `torch.compile`.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x


####################################################
# Next, let's review the difference between full model and regional compilation.
#
# In full model compilation, the entire model is compiled as a whole. This is the common approach
# most users take with ``torch.compile``. In this example, we apply ``torch.compile`` to
# the ``Model`` object. This will effectively inline the 64 layers, producing a
# large graph to compile. You can look at the full graph by running this recipe
# with ``TORCH_LOGS=graph_code``.
#

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)


###################################################
# Regional compilation, on the other hand, compiles only a region of the model.
# By strategically choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions.
# In this example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

#####################################################
# Applying compilation to a repeated region, instead of the full model, leads to
# large savings in compile time. Here, we compile just one layer instance and
# then reuse it 64 times in the ``Model`` object.
#
# Note that with repeated regions, some parts of the model might not be compiled.
# For example, the ``self.linear`` in the ``Model`` is outside the scope of
# regional compilation.
#
# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation involves a larger graph and,
# theoretically, offers more scope for optimizations. However, for practical
# purposes and depending on the model, we have observed many cases with minimal
# speedup differences between full model and regional compilation.


###################################################
# Next, let's measure the compilation time of the full model and the regional compilation.
#
# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
# In the code below, we measure the total time spent in the first invocation. While this method is not
# precise, it provides a good estimate since the majority of the time is spent in
# compilation.


def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency

############################################################################
# Conclusion
# ----------
#
# This recipe shows how to reduce the cold start compilation time if your model
# has repeated regions. This approach requires the user to modify the model code and apply ``torch.compile`` to
# the repeated regions instead of the more commonly used full model compilation. We
# are continually working on reducing cold start compilation time.
#
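
######################################################################
# As a further illustration (a minimal sketch added here, not part of the
# original recipe), the same technique can be retrofitted onto a model that
# has already been constructed, without editing its class definition: wrap
# each repeated submodule with ``torch.compile`` in place. The ``Model`` and
# ``Layer`` classes are the ones defined earlier in this recipe; everything
# else below is illustrative.

# Build an eager model, then compile only its repeated regions in place.
eager_model = Model(apply_regional_compilation=False).cuda()
for idx, layer in enumerate(eager_model.layers):
    # ``torch.compile`` returns an ``nn.Module`` wrapper, so it can be
    # swapped back into the ``ModuleList``.
    eager_model.layers[idx] = torch.compile(layer)

# The first forward pass compiles a single ``Layer`` graph, which is then
# reused across all 64 instances, just like ``regional_compiled_model`` above.
_ = eager_model(torch.randn(10, 10, device="cuda"))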