"""
Reducing torch.compile cold start compilation time with regional compilation
============================================================================

**Author:** `Animesh Jain <https://github.com/anijain2305>`_

As deep learning models get larger, the compilation time of these models also
increases. This extended compilation time can result in a large startup time in
inference services or wasted resources in large-scale training. This recipe
shows an example of how to reduce the cold start compilation time by choosing to
compile a repeated region of the model instead of the entire model.

Prerequisites
-------------

* PyTorch 2.5 or later

Setup
-----
Before we begin, we need to install ``torch`` if it is not already
available.

.. code-block:: sh

   pip install torch

.. note::
   This feature is available starting with the 2.5 release. If you are using version 2.4,
   you can enable the configuration flag ``torch._dynamo.config.inline_inbuilt_nn_modules=True``
   to prevent recompilations during regional compilation. In version 2.5, this flag is enabled by default.
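
   For example, on a PyTorch 2.4 installation you could set the flag before compiling
   anything (a minimal sketch):

   .. code-block:: python

      import torch

      torch._dynamo.config.inline_inbuilt_nn_modules = True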
"""

from time import perf_counter

######################################################################
# Steps
# -----
#
# In this recipe, we will follow these steps:
#
# 1. Import all necessary libraries.
# 2. Define and initialize a neural network with repeated regions.
# 3. Understand the difference between full model compilation and regional compilation.
# 4. Measure the compilation time of full model compilation and regional compilation.
#
# First, let's import the necessary libraries:
#

import torch
import torch.nn as nn


##########################################################
# Next, let's define and initialize a neural network with repeated regions.
#
# Typically, neural networks are composed of repeated layers. For example, a
# large language model is composed of many Transformer blocks. In this recipe,
# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
# We will then create a ``Model`` which is composed of 64 instances of this
# ``Layer`` class.
#
class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply torch.compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, self.linear is outside the scope of torch.compile.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x


####################################################
# Next, let's review the difference between full model compilation and regional compilation.
#
# In full model compilation, the entire model is compiled as a whole. This is the
# common approach most users take with ``torch.compile``. In this example, we apply
# ``torch.compile`` to the ``Model`` object, which effectively inlines the 64 layers
# and produces one large graph to compile. You can look at the full graph by running
# this recipe with ``TORCH_LOGS=graph_code``.
#
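#
# For example, a minimal way to dump the graph code (assuming you saved this
# recipe as ``regional_compilation.py``):
#
# .. code-block:: sh
#
#    TORCH_LOGS=graph_code python regional_compilation.py
#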

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)


###################################################
# Regional compilation, on the other hand, compiles only a region of the model.
# By strategically choosing to compile a repeated region of the model, we can compile a
# much smaller graph and then reuse the compiled graph for all the regions.
# In this example, ``torch.compile`` is applied only to the ``layers`` and not the full model.
#

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

#####################################################
# Applying compilation to a repeated region, instead of the full model, leads to
# large savings in compile time. Here, we compile just one ``Layer`` instance and
# then reuse it 64 times in the ``Model`` object.
#
# Note that with repeated regions, some parts of the model might not be compiled.
# For example, ``self.linear`` in the ``Model`` is outside the scope of regional
# compilation; the quick check below confirms this.
#
# Also, note that there is a tradeoff between performance speedup and compile
# time. Full model compilation involves a larger graph and, theoretically, offers
# more scope for optimizations. However, for practical purposes and depending on
# the model, we have observed many cases with minimal speedup differences between
# full model and regional compilation.
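#
# As a quick check (a minimal sketch, not part of the original recipe; it relies on
# the internal ``OptimizedModule`` wrapper type that ``torch.compile`` returns for an
# ``nn.Module``), you can confirm which submodules are wrapped by the compiler:

from torch._dynamo.eval_frame import OptimizedModule

# ``self.linear`` is left as a plain ``nn.Linear`` and runs eagerly.
print(isinstance(regional_compiled_model.linear, OptimizedModule))
# Each repeated ``Layer`` is wrapped by ``torch.compile``.
print(isinstance(regional_compiled_model.layers[0], OptimizedModule))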


###################################################
# Next, let's measure the compilation time of the full model and the regional compilation.
#
# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
# In the code below, we measure the total time spent in the first invocation. While this method is not
# precise, it provides a good estimate since the majority of the time is spent in
# compilation.


def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency

############################################################################
# Conclusion
# ----------
#
# This recipe shows how to control the cold start compilation time if your model
# has repeated regions. This approach requires user modifications to apply
# ``torch.compile`` to the repeated regions instead of the more commonly used
# full model compilation. We are continually working on reducing cold start
# compilation time.
#
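#
# If you prefer not to modify the model's constructor, a post-hoc variant is also
# possible (a minimal sketch, not part of the original recipe): ``nn.Module.compile``
# wraps a module's ``forward`` with ``torch.compile`` in place, so you can compile
# each repeated submodule after construction.

post_hoc_model = Model(apply_regional_compilation=False).cuda()
for layer in post_hoc_model.layers:
    # Compile only the repeated region; post_hoc_model.linear stays eager.
    layer.compile()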