1"""2Reducing AoT cold start compilation time with regional compilation3============================================================================45**Author:** `Sayak Paul <https://huggingface.co/sayakpaul>`_, `Charles Bensimon <https://huggingface.co/cbensimon>`_, `Angela Yi <https://github.com/angelayi>`_67In the `regional compilation recipe <https://docs.pytorch.org/tutorials/recipes/regional_compilation.html>`__, we showed8how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for9just-in-time (JIT) compilation.1011This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you12are not familiar with AOTInductor and ``torch.export``, we recommend you to check out `this tutorial <https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`__.1314Prerequisites15----------------1617* Pytorch 2.6 or later18* Familiarity with regional compilation19* Familiarity with AOTInductor and ``torch.export``2021Setup22-----23Before we begin, we need to install ``torch`` if it is not already24available.2526.. code-block:: sh2728pip install torch29"""3031######################################################################32# Steps33# -----34#35# In this recipe, we will follow the same steps as the regional compilation recipe mentioned above:36#37# 1. Import all necessary libraries.38# 2. Define and initialize a neural network with repeated regions.39# 3. Measure the compilation time of the full model and the regional compilation with AoT.40#41# First, let's import the necessary libraries for loading our data:42#4344import torch45torch.set_grad_enabled(False)4647from time import perf_counter4849###################################################################################50# Defining the Neural Network51# ---------------------------52#53# We will use the same neural network structure as the regional compilation recipe.54#55# We will use a network, composed of repeated layers. This mimics a56# large language model, that typically is composed of many Transformer blocks. In this recipe,57# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.58# We will then create a ``Model`` which is composed of 64 instances of this59# ``Layer`` class.60#61class Layer(torch.nn.Module):62def __init__(self):63super().__init__()64self.linear1 = torch.nn.Linear(10, 10)65self.relu1 = torch.nn.ReLU()66self.linear2 = torch.nn.Linear(10, 10)67self.relu2 = torch.nn.ReLU()6869def forward(self, x):70a = self.linear1(x)71a = self.relu1(a)72a = torch.sigmoid(a)73b = self.linear2(a)74b = self.relu2(b)75return b767778class Model(torch.nn.Module):79def __init__(self):80super().__init__()81self.linear = torch.nn.Linear(10, 10)82self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])8384def forward(self, x):85# In regional compilation, the self.linear is outside of the scope of ``torch.compile``.86x = self.linear(x)87for layer in self.layers:88x = layer(x)89return x909192##################################################################################93# Compiling the model ahead-of-time94# ---------------------------------95#96# Since we're compiling the model ahead-of-time, we need to prepare representative97# input examples, that we expect the model to see during actual deployments.98#99# Let's create an instance of ``Model`` and pass it some sample input data.100#101102model = Model().cuda()103input = torch.randn(10, 10, device="cuda")104output = model(input)105print(f"{output.shape=}")106107###############################################################################################108# Now, let's compile our model ahead-of-time. We will use ``input`` created above to pass109# to ``torch.export``. This will yield a ``torch.export.ExportedProgram`` which we can compile.110111path = torch._inductor.aoti_compile_and_package(112torch.export.export(model, args=(input,))113)114115#################################################################116# We can load from this ``path`` and use it to perform inference.117118compiled_binary = torch._inductor.aoti_load_package(path)119output_compiled = compiled_binary(input)120print(f"{output_compiled.shape=}")121122######################################################################################123# Compiling _regions_ of the model ahead-of-time124# ----------------------------------------------125#126# Compiling model regions ahead-of-time, on the other hand, requires a few key changes.127#128# Since the compute pattern is shared by all the blocks that129# are repeated in a model (``Layer`` instances in this cases), we can just130# compile a single block and let the inductor reuse it.131132model = Model().cuda()133path = torch._inductor.aoti_compile_and_package(134torch.export.export(model.layers[0], args=(input,)),135inductor_configs={136# compile artifact w/o saving params in the artifact137"aot_inductor.package_constants_in_so": False,138}139)140141###################################################142# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,143# a ``state_dict`` containing tensor values of all lifted parameters and buffer alongside144# other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to145# not serialize the model parameters in the generated artifact.146#147# Now, when loading the compiled binary, we can reuse the existing parameters of148# each block. This lets us take advantage of the compiled binary obtained above.149#150151for layer in model.layers:152compiled_layer = torch._inductor.aoti_load_package(path)153compiled_layer.load_constants(154layer.state_dict(), check_full_update=True, user_managed=True155)156layer.forward = compiled_layer157158output_regional_compiled = model(input)159print(f"{output_regional_compiled.shape=}")160161#####################################################162# Just like JIT regional compilation, compiling regions within a model ahead-of-time163# leads to significantly reduced cold start times. The actual number will vary from164# model to model.165#166# Even though full model compilation offers the fullest scope of optimizations,167# for practical purposes and depending on the type of model, we have seen regional168# compilation (both JiT and AoT) providing similar speed benefits, while drastically169# reducing the cold start times.170171###################################################172# Measuring compilation time173# --------------------------174# Next, let's measure the compilation time of the full model and the regional compilation.175#176177def measure_compile_time(input, regional=False):178start = perf_counter()179model = aot_compile_load_model(regional=regional)180torch.cuda.synchronize()181end = perf_counter()182# make sure the model works.183_ = model(input)184return end - start185186def aot_compile_load_model(regional=False) -> torch.nn.Module:187input = torch.randn(10, 10, device="cuda")188model = Model().cuda()189190inductor_configs = {}191if regional:192inductor_configs = {"aot_inductor.package_constants_in_so": False}193194# Reset the compiler caches to ensure no reuse between different runs195torch.compiler.reset()196with torch._inductor.utils.fresh_inductor_cache():197path = torch._inductor.aoti_compile_and_package(198torch.export.export(199model.layers[0] if regional else model,200args=(input,)201),202inductor_configs=inductor_configs,203)204205if regional:206for layer in model.layers:207compiled_layer = torch._inductor.aoti_load_package(path)208compiled_layer.load_constants(209layer.state_dict(), check_full_update=True, user_managed=True210)211layer.forward = compiled_layer212else:213model = torch._inductor.aoti_load_package(path)214return model215216input = torch.randn(10, 10, device="cuda")217full_model_compilation_latency = measure_compile_time(input, regional=False)218print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")219220regional_compilation_latency = measure_compile_time(input, regional=True)221print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")222223assert regional_compilation_latency < full_model_compilation_latency224225############################################################################226# There may also be layers in a model incompatible with compilation. So,227# full compilation will result in a fragmented computation graph resulting228# in potential latency degradation. In these case, regional compilation229# can be beneficial.230#231232############################################################################233# Conclusion234# -----------235#236# This recipe shows how to control the cold start time when compiling your237# model ahead-of-time. This becomes effective when your model has repeated238# blocks, which is typically seen in large generative models. We used this239# recipe on various models to speed up real-time performance. Learn more240# `here <https://huggingface.co/blog/zerogpu-aoti>`__.241242243