GitHub Repository: pytorch/tutorials
Path: blob/main/recipes_source/torch_export_aoti_python.py
# -*- coding: utf-8 -*-

"""
.. meta::
    :description: An end-to-end example of how to use AOTInductor for Python runtime.
    :keywords: torch.export, AOTInductor, torch._inductor.aot_compile, torch._export.aot_load

``torch.export`` AOTInductor Tutorial for Python runtime (Beta)
===============================================================
**Authors:** Ankith Gunapal, Bin Bao, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to backwards compatibility
#     breaking changes. This tutorial provides an example of how to use these APIs for model deployment with the Python runtime.
#
# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
# to do Ahead-of-Time compilation of PyTorch exported models by creating
# a shared library that can be run in a non-Python environment.
#
# In this tutorial, you will learn an end-to-end example of how to use AOTInductor for the Python runtime.
# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
# shared library. Additionally, we will examine how to execute the shared library in the Python runtime using :func:`torch._export.aot_load`.
# You will learn about the speedup in first inference time when using AOTInductor, especially with
# ``max-autotune`` mode, which can take some time to execute.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Prerequisites
# -------------
# * PyTorch 2.4 or later
# * Basic understanding of ``torch.export`` and AOTInductor
# * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

######################################################################
# What you will learn
# ----------------------
# * How to use AOTInductor for Python runtime.
# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library.
# * How to run a shared library in the Python runtime using :func:`torch._export.aot_load`.
# * When to use AOTInductor for Python runtime.

######################################################################
# Model Compilation
# -----------------
#
# We will use the TorchVision pretrained ``ResNet18`` model and run TorchInductor on the
# exported PyTorch program using :func:`torch._inductor.aot_compile`.
#
# .. note::
#
#       This API also supports :func:`torch.compile` options like ``mode``.
#       This means that if used on a CUDA enabled device, you can, for example, set ``"max_autotune": True``,
#       which leverages Triton based matrix multiplications & convolutions, and enables CUDA graphs by default.
#
# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Specify the generated shared library path
        options=aot_compile_options
    )

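######################################################################
# ``torch._inductor.aot_compile`` returns the path of the generated shared library.
# As an optional sanity check (an illustrative addition to this recipe, not part of the
# original), we can confirm that the artifact was written to the
# ``aot_inductor.output_path`` we requested above.

# Optional sanity check: the returned path should point at the shared library on disk.
print(f"AOTInductor wrote the shared library to: {so_path}")
print(f"Exists: {os.path.exists(so_path)}, size: {os.path.getsize(so_path) / 1e6:.2f} MB")
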
######################################################################
# Model Inference in Python
# -------------------------
#
# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
# The API follows a structure similar to the :func:`torch.jit.load` API. You need to specify the path
# of the shared library and the device where it should be loaded.
#
# .. note::
#       In the example below, we run inference with ``batch_size=1``, and it still functions correctly even though we specified ``min=2`` in
#       :func:`torch.export.export`.

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)

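######################################################################
# Because we exported the model with a dynamic batch dimension (``min=2``, ``max=32``),
# the same shared library can serve other batch sizes in that range without recompilation.
# The following is a short illustrative sketch (not part of the original recipe) that
# reruns inference with a couple of different batch sizes.

with torch.inference_mode():
    for batch_size in (2, 16):
        batch = (torch.randn(batch_size, 3, 224, 224, device=device),)
        out = model(batch)
        # The output mirrors the eager model's return value (ResNet18 logits of shape
        # ``(batch_size, 1000)``); unwrap it if it comes back inside a container.
        logits = out[0] if isinstance(out, (list, tuple)) else out
        print(f"batch_size={batch_size} -> logits shape {tuple(logits.shape)}")
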
######################################################################
# When to use AOTInductor for Python Runtime
# ------------------------------------------
#
# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
# Once this requirement is met, the primary use case for the AOTInductor Python runtime is
# model deployment using Python.
# There are mainly two reasons why you would use the AOTInductor Python runtime:
#
# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
#   versioning for deployments and tracking model performance over time.
# - With :func:`torch.compile` being a JIT compiler, there is a warmup
#   cost associated with the first compilation. Your deployment needs to account for the
#   compilation time taken for the first inference. With AOTInductor, the compilation is
#   done offline using ``torch.export.export`` & ``torch._inductor.aot_compile``. At deployment
#   time, you only load the shared library using ``torch._export.aot_load`` and run inference,
#   as sketched below.
#
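######################################################################
# The two phases can live in separate scripts: an offline step that exports and compiles
# the model, and a serving step that only loads the shared library. The snippet below is a
# minimal sketch of this split, reusing the APIs shown in this recipe; the
# ``/path/to/resnet18_pt2.so`` location is a hypothetical placeholder.
#
# .. code-block:: python
#
#     # Offline: export and AOT-compile once, then ship the resulting .so
#     exported = torch.export.export(model, example_inputs)
#     so_path = torch._inductor.aot_compile(
#         exported.module(),
#         example_inputs,
#         options={"aot_inductor.output_path": "/path/to/resnet18_pt2.so"},
#     )
#
#     # Serving: load the shared library and run inference; no compilation at startup
#     runner = torch._export.aot_load("/path/to/resnet18_pt2.so", device)
#     output = runner((torch.randn(2, 3, 224, 224, device=device),))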

######################################################################
# The section below shows the speedup achieved with AOTInductor for first inference.
#
# We define a utility function ``timed`` to measure the time taken for inference.

import time


def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration

######################################################################
# Let's measure the time for first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")

######################################################################
# Let's measure the time for first inference using ``torch.compile``.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")

######################################################################
# We see that there is a drastic speedup in first inference time using AOTInductor compared
# to ``torch.compile``.

######################################################################
# Conclusion
# ----------
#
# In this recipe, we have learned how to effectively use AOTInductor for the Python runtime by
# compiling and loading a pretrained ``ResNet18`` model with the ``torch._inductor.aot_compile``
# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
# generating a shared library and running it within a Python environment, even with dynamic shape
# considerations and device-specific optimizations. We also looked at the advantage of using
# AOTInductor in model deployments, with regard to the speedup in first inference time.