# -*- coding: utf-8 -*-

"""
torch.export Nightly Tutorial
=============================
**Author:** William Wen, Zhengxu Chen, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
#     breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.1.
#
# :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
# standardized model representations, intended
# to be run in different (i.e. Python-less) environments.
#
# In this tutorial, you will learn how to use :func:`torch.export` to extract
# ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
# We also detail some considerations/modifications that you may need
# to make in order to make your model compatible with ``torch.export``.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Basic Usage
# -----------
#
# ``torch.export`` extracts single-graph representations from PyTorch programs
# by tracing the target function, given example inputs.
# ``torch.export.export()`` is the main entry point for ``torch.export``.
#
# In this tutorial, ``torch.export`` and ``torch.export.export()`` are practically synonymous,
# though ``torch.export`` generally refers to the PyTorch 2.X export process, and ``torch.export.export()``
# generally refers to the actual function call.
#
# The signature of ``torch.export.export()`` is:
#
# .. code:: python
#
#     export(
#         f: Callable,
#         args: Tuple[Any, ...],
#         kwargs: Optional[Dict[str, Any]] = None,
#         *,
#         dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
#     ) -> ExportedProgram
#
# ``torch.export.export()`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
# and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
# different inputs.
# Note that while the output ``ExportedProgram`` is callable and can be
# called in the same way as the original input callable, it is not a ``torch.nn.Module``.
# We will detail the ``dynamic_shapes`` argument later in the tutorial.

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))

######################################################################
# Let's review some attributes of ``ExportedProgram`` that are of interest.
#
# The ``graph`` attribute is an `FX graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
# traced from the function we exported, that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
#
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
#
# The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
# so that it can be run as a ``torch.nn.Module``.

print(exported_mod)
print(exported_mod.graph_module)

######################################################################
# The printed code shows that the FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
# Future uses of the input to the original mutating ``relu`` op are replaced by the additional new output
# of the replacement non-mutating ``relu`` op.
#
# Other attributes of interest in ``ExportedProgram`` include:
#
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later in this tutorial.

print(exported_mod.graph_signature)
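######################################################################
# Earlier we mentioned that an ``ExportedProgram`` can be serialized. As a
# minimal, hedged sketch (assuming your build provides the prototype
# ``torch.export.save`` and ``torch.export.load`` APIs, present as of
# PyTorch 2.1), we can round-trip the exported program through disk and
# call the loaded copy:

import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    # hypothetical file name; the ``.pt2`` suffix follows the export docs
    path = os.path.join(tmpdir, "exported_mod.pt2")
    torch.export.save(exported_mod, path)  # serialize to disk
    loaded_mod = torch.export.load(path)   # deserialize
    print(loaded_mod(torch.randn(8, 100), torch.randn(8, 100)))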
######################################################################
# See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
# for more details.

######################################################################
# Graph Breaks
# ------------
#
# Although ``torch.export`` shares components with ``torch.compile``,
# the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
# support graph breaks. This is because handling graph breaks involves interpreting
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case. Therefore, in order to make your model code compatible
# with ``torch.export``, you will need to modify your code to remove graph breaks.
#
# A graph break is necessary in cases such as:
#
# - data-dependent control flow

def bad1(x):
    if x.sum() > 0:
        return torch.sin(x)
    return torch.cos(x)

import traceback as tb
try:
    export(bad1, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - accessing tensor data with ``.data``

def bad2(x):
    x.data[0, 0] = 3
    return x

try:
    export(bad2, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - calling unsupported functions (such as many built-in functions)

def bad3(x):
    x = x + 1
    return x + id(x)

try:
    export(bad3, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - unsupported Python language features (e.g. throwing exceptions, match statements)

def bad4(x):
    try:
        x = x + 1
        raise RuntimeError("bad")
    except:
        x = x + 2
    return x

try:
    export(bad4, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# The sections below demonstrate some ways you can modify your code
# in order to remove graph breaks.

######################################################################
# Control Flow Ops
# ----------------
#
# ``torch.export`` actually does support data-dependent control flow,
# but it needs to be expressed using control flow ops. For example,
# we can fix the control flow example above using the ``cond`` op, like so:

from functorch.experimental.control_flow import cond

def bad1_fixed(x):
    def true_fn(x):
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
    return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
# There are limitations to ``cond`` that one should be aware of
# (we demonstrate one of them after the list):
#
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch functions' (i.e. ``true_fn`` and ``false_fn``) signatures must match the
#   operands, and both branches must return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
# - Branch functions cannot mutate input or global variables.
# - Branch functions cannot access closure variables, except for ``self`` if the function is
#   defined in the scope of a method.
#
# For more details about ``cond``, check out the `documentation <https://pytorch.org/docs/main/cond.html>`__.
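######################################################################
# For instance, here is a hedged sketch (exact error messages may differ)
# that violates the same-metadata rule above: ``true_fn`` returns a
# ``float64`` tensor while ``false_fn`` returns ``float32``, so we expect
# export to fail.

def cond_mismatched(x):
    def true_fn(x):
        return x.to(torch.float64)  # float64 output
    def false_fn(x):
        return x                    # float32 output: metadata mismatch
    return cond(x.sum() > 0, true_fn, false_fn, [x])

try:
    export(cond_mismatched, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()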
######################################################################
# ..
#     [NOTE] map is not documented at the moment
#     We can also use ``map``, which applies a function across the first dimension
#     of the first tensor argument.
#
#     from functorch.experimental.control_flow import map
#
#     def map_example(xs):
#         def map_fn(x, const):
#             def true_fn(x):
#                 return x + const
#             def false_fn(x):
#                 return x - const
#             return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
#         return control_flow.map(map_fn, xs, torch.tensor([2.0]))
#
#     exported_map_example = export(map_example, (torch.randn(4, 3),))
#     inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
#     print(exported_map_example(inp))

######################################################################
# Constraints/Dynamic Shapes
# --------------------------
#
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export.export()`` call.
# If we try to run the ``ExportedProgram`` in the example below with a tensor
# with a different shape, we get an error:

class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()
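######################################################################
# As a small sanity check (hedged sketch), calling the exported program with
# inputs of the original example shape still works; only the mismatched shape
# above fails.

print(exported_mod2(torch.randn(8, 100), torch.randn(8, 100)))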
######################################################################
# We can relax this constraint using the ``dynamic_shapes`` argument of
# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
# which dimensions of the input tensors are dynamic.
#
# For each tensor argument of the input callable, we can specify a mapping from the dimension
# to a ``torch.export.Dim``.
# A ``torch.export.Dim`` is essentially a named symbolic integer with optional
# minimum and maximum bounds.
#
# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping
# from the input callable's tensor argument names to dimension --> ``Dim`` mappings as described above.
# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is
# assumed to be static.
#
# The first argument of ``torch.export.Dim`` is the name of the symbolic integer, used for debugging.
# Then we can specify an optional minimum and maximum bound (inclusive).
# Below, we show example usage.
#
# In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval [4, 18].

from torch.export import Dim

inp1 = torch.randn(10, 10, 2)

def dynamic_shapes_example1(x):
    x = x[:, 2:]
    return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
    "x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1(torch.randn(5, 5, 2)))

try:
    exported_dynamic_shapes_example1(torch.randn(8, 1, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1(torch.randn(8, 20, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1(torch.randn(8, 8, 3))
except Exception:
    tb.print_exc()

######################################################################
# Note that if our example inputs to ``torch.export`` do not satisfy the constraints
# given by ``dynamic_shapes``, then we get an error.

inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
}

try:
    export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
    tb.print_exc()

######################################################################
# We can enforce equalities between dimensions of different tensors
# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication:

inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

def dynamic_shapes_example2(x, y):
    return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")

dynamic_shapes2 = {
    "x": {0: inp2_dim0, 1: inner_dim},
    "y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(dynamic_shapes_example2, (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2(torch.randn(2, 16), torch.randn(16, 4)))

try:
    exported_dynamic_shapes_example2(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()

######################################################################
# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
# are necessary. We can do this by relaxing all constraints (recall that if we
# do not provide constraints for a dimension, the default behavior is to constrain
# to the exact shape value of the example input) and letting ``torch.export``
# error out.

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

def dynamic_shapes_example3(x, y):
    if x.shape[0] <= 16:
        return x @ y[:, :16]
    return y

dynamic_shapes3 = {
    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
    export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
    tb.print_exc()

######################################################################
# We can see that the error message gives us suggested fixes to our
# dynamic shape constraints.
# Let us follow those suggestions (exact
# suggestions may differ slightly):

def suggested_fixes():
    inp4_dim1 = Dim('shared_dim')
    # suggested fixes below
    inp4_dim0 = Dim('inp4_dim0', max=16)
    inp5_dim1 = Dim('inp5_dim1', min=17)
    inp5_dim0 = inp4_dim1
    # end of suggested fixes
    return {
        "x": {0: inp4_dim0, 1: inp4_dim1},
        "y": {0: inp5_dim0, 1: inp5_dim1},
    }

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3(torch.randn(4, 32), torch.randn(32, 64)))

######################################################################
# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
# ``dynamic_shapes_example3``, the exported program is sound even though there is a
# raw ``if`` statement.
#
# If you want to see why ``torch.export`` generated these constraints, you can
# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``,
# or use ``torch._logging.set_logs``.

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

######################################################################
# We can view an ``ExportedProgram``'s constraints using the ``range_constraints`` and
# ``equality_constraints`` attributes. The logging above reveals what the symbols ``s0, s1, ...``
# represent.

print(exported_dynamic_shapes_example3.range_constraints)
print(exported_dynamic_shapes_example3.equality_constraints)

######################################################################
# Custom Ops
# ----------
#
# ``torch.export`` can export PyTorch programs with custom operators.
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
#
# - If you're writing custom ops purely in Python, use ``torch.library.custom_op``.

import torch.library
import numpy as np

@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

######################################################################
# - You will need to provide an abstract ("fake") implementation so that PT2 can trace through it.

@torch.library.register_fake("mylib::sin")
def _(x):
    return torch.empty_like(x)
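######################################################################
# As a quick, hedged sanity check of the registration above, the custom op is
# now callable eagerly through the ``torch.ops`` namespace and should agree
# with ``torch.sin``:

t = torch.randn(3)
print(torch.ops.mylib.sin(t))
print(torch.sin(t))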
######################################################################
# - Sometimes, the custom op you are exporting has data-dependent output, meaning
#   we can't determine the shape of the output at compile time. In this case, you
#   can do the following:

@torch.library.custom_op("mylib::nonzero", mutates_args=())
def nonzero(x: torch.Tensor) -> torch.Tensor:
    x_np = x.cpu().numpy()
    res = np.stack(np.nonzero(x_np), axis=1)
    return torch.tensor(res, device=x.device)

@torch.library.register_fake("mylib::nonzero")
def _(x):
    # The number of nonzero elements is data-dependent.
    # Since we cannot peek at the data in an abstract implementation,
    # we use the ``ctx`` object to construct a new ``symint`` that
    # represents the data-dependent size.
    ctx = torch.library.get_ctx()
    nnz = ctx.new_dynamic_size()
    shape = [nnz, x.dim()]
    result = x.new_empty(shape, dtype=torch.int64)
    return result

######################################################################
# - Call the custom op from the code you want to export using ``torch.ops``

def custom_op_example(x):
    x = torch.sin(x)
    x = torch.ops.mylib.sin(x)
    x = torch.cos(x)
    y = torch.ops.mylib.nonzero(x)
    return x + y.sum()
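######################################################################
# Before exporting, another hedged sanity check: called eagerly, our
# data-dependent custom ``nonzero`` should agree with ``torch.nonzero``
# (both return one row of indices per nonzero element).

t2 = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
print(torch.ops.mylib.nonzero(t2))
print(torch.nonzero(t2))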
######################################################################
# - Export the code as before

exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example(torch.randn(3, 3)))

######################################################################
# Note in the above outputs that the custom op is included in the exported graph,
# and when we call the exported graph as a function, the original custom op is called,
# as evidenced by the ``print`` call.
#
# If you have a custom operator implemented in C++, please refer to
# `this document <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz>`__
# to make it compatible with ``torch.export``.

######################################################################
# Decompositions
# --------------
#
# By default, the graph produced by ``torch.export`` contains
# only functional ATen operators. This functional ATen operator set (or "opset") contains around 2000
# operators, all of which are functional, that is, they do not
# mutate or alias inputs. You can find a list of all ATen operators
# `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
# and you can inspect if an operator is functional by checking
# ``op._schema.is_mutable``, for example:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

######################################################################
# By default, the environment in which you want to run the exported graph
# should support all ~2000 of these operators.
# However, you can use the following API on the exported program
# if your specific environment is only able to support a subset of
# the ~2000 operators.
#
# .. code:: python
#
#     def run_decompositions(
#         self: ExportedProgram,
#         decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
#     ) -> ExportedProgram
#
# ``run_decompositions`` takes in a decomposition table, which is a mapping of
# operators to a function specifying how to reduce, or decompose, that operator
# into an equivalent sequence of other ATen operators.
#
# The default decomposition table for ``run_decompositions`` is the
# `Core ATen decomposition table <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L252>`__,
# which will decompose all ATen operators to the
# `Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
# which consists of only ~180 operators.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)

######################################################################
# Notice that after running ``run_decompositions`` the
# ``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
# Opset, has been replaced with ``torch.ops.aten.permute.default``, which is part
# of the Core ATen Opset.

######################################################################
# Most ATen operators already have decompositions, which are located
# `here <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/decompositions.py>`__.
# If you would like to use some of these existing decomposition functions,
# you can pass in a list of operators you would like to decompose to the
# `get_decompositions <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L191>`__
# function, which will return a decomposition table using existing
# decomposition implementations.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)
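######################################################################
# As a hedged sanity check, we can scan the decomposed graph's
# ``call_function`` nodes to confirm that ``aten.t.default`` was indeed
# decomposed away:

decomposed_targets = {
    node.target for node in core_ir_ep.graph.nodes if node.op == "call_function"
}
print(torch.ops.aten.t.default in decomposed_targets)  # expected: False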
######################################################################
# If there is no existing decomposition function for an ATen operator that you would
# like to decompose, feel free to send a pull request into PyTorch
# implementing the decomposition!

######################################################################
# ExportDB
# --------
#
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
# there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to
# rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting
# if-statements using ``cond``.
#
# `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ is the standard reference that documents
# supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list of program samples, each
# of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``.
# Examples are also tagged by category so that they can be more easily searched.
#
# For example, let's use ExportDB to get a better understanding of how the predicate works in the ``cond`` operator.
# We can look at the example called ``cond_predicate``, which has a ``torch.cond`` tag. The example code looks like:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - torch.Tensor with a single element
    - boolean expression
    NOTE: If ``pred`` is a test on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

######################################################################
# More generally, ExportDB can be used as a reference when one of the following occurs:
#
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features
#    and you want to know if ``torch.export`` covers that feature.
# 2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it.
#
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

######################################################################
# Conclusion
# ----------
#
# We introduced ``torch.export``, the new PyTorch 2.X way to export single computation
# graphs from PyTorch programs. In particular, we demonstrated several code modifications
# and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.