# -*- coding: utf-8 -*-

"""
torch.export Nightly Tutorial
=============================
**Author:** William Wen, Zhengxu Chen, Angela Yi
"""

######################################################################
#
# .. warning::
#
#     ``torch.export`` and its related features are in prototype status and are
#     subject to backwards-compatibility-breaking changes. This tutorial provides
#     a snapshot of ``torch.export`` usage as of PyTorch 2.1.
#
# :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
# standardized model representations, intended
# to be run in different (i.e. Python-less) environments.
#
# In this tutorial, you will learn how to use :func:`torch.export` to extract
# ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
# We also detail some considerations/modifications that you may need
# to make in order to make your model compatible with ``torch.export``.
#
# **Contents**
#
# .. contents::
#     :local:

######################################################################
# Basic Usage
# -----------
#
# ``torch.export`` extracts single-graph representations from PyTorch programs
# by tracing the target function, given example inputs.
# ``torch.export.export()`` is the main entry point for ``torch.export``.
#
# In this tutorial, ``torch.export`` and ``torch.export.export()`` are practically synonymous,
# though ``torch.export`` generally refers to the PyTorch 2.X export process, and ``torch.export.export()``
# generally refers to the actual function call.
#
# The signature of ``torch.export.export()`` is:
#
# .. code:: python
#
#     export(
#         f: Callable,
#         args: Tuple[Any, ...],
#         kwargs: Optional[Dict[str, Any]] = None,
#         *,
#         dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
#     ) -> ExportedProgram
#
# ``torch.export.export()`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
# and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
# different inputs. Note that while the output ``ExportedProgram`` is callable and can be
# called in the same way as the original input callable, it is not a ``torch.nn.Module``.
# We will detail the ``dynamic_shapes`` argument later in the tutorial.

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
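
######################################################################
# The ``ExportedProgram`` above can also be saved to disk and loaded back later.
# As a minimal sketch (assuming a PyTorch build that provides the
# ``torch.export.save`` and ``torch.export.load`` APIs):
#
# .. code:: python
#
#     # hypothetical round trip; the file name is arbitrary
#     torch.export.save(exported_mod, "exported_mod.pt2")
#     loaded_mod = torch.export.load("exported_mod.pt2")
#     print(loaded_mod(torch.randn(8, 100), torch.randn(8, 100)))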

######################################################################
# Let's review some attributes of ``ExportedProgram`` that are of interest.
#
# The ``graph`` attribute is an `FX graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
# traced from the function we exported, that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
#
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
#
# The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
# so that it can be run as a ``torch.nn.Module``.

print(exported_mod)
print(exported_mod.graph_module)

######################################################################
# The printed code shows that the FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
# Future uses of the input to the original mutating ``relu`` op are replaced by the additional new output
# of the replacement non-mutating ``relu`` op.
#
# Other attributes of interest in ``ExportedProgram`` include:
#
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later.

print(exported_mod.graph_signature)

######################################################################
# See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
# for more details.

######################################################################
# Graph Breaks
# ------------
#
# Although ``torch.export`` shares components with ``torch.compile``,
# the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
# support graph breaks. This is because handling graph breaks involves interpreting
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case. Therefore, in order to make your model code compatible
# with ``torch.export``, you will need to modify your code to remove graph breaks.
#
# A graph break is necessary in cases such as:
#
# - data-dependent control flow

def bad1(x):
    if x.sum() > 0:
        return torch.sin(x)
    return torch.cos(x)

import traceback as tb
try:
    export(bad1, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - accessing tensor data with ``.data``

def bad2(x):
    x.data[0, 0] = 3
    return x

try:
    export(bad2, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - calling unsupported functions (such as many built-in functions)

def bad3(x):
    x = x + 1
    return x + id(x)

try:
    export(bad3, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# - unsupported Python language features (e.g. throwing exceptions, match statements)

def bad4(x):
    try:
        x = x + 1
        raise RuntimeError("bad")
    except:
        x = x + 2
    return x

try:
    export(bad4, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

######################################################################
# The sections below demonstrate some ways you can modify your code
# in order to remove graph breaks.

######################################################################
# Control Flow Ops
# ----------------
#
# ``torch.export`` actually does support data-dependent control flow,
# but it needs to be expressed using control flow ops. For example,
# we can fix the control flow example above using the ``cond`` op, like so:

from functorch.experimental.control_flow import cond

def bad1_fixed(x):
    def true_fn(x):
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
    return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
# There are limitations to ``cond`` that one should be aware of:
#
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch functions' (i.e. ``true_fn`` and ``false_fn``) signatures must match the
#   operands, and both branches must return a single tensor with the same metadata
#   (for example, ``dtype``, ``shape``, etc.).
# - Branch functions cannot mutate inputs or global variables (see the illustration after this list).
# - Branch functions cannot access closure variables, except for ``self`` if the function is
#   defined in the scope of a method.
#
# For more details about ``cond``, check out the `documentation <https://pytorch.org/docs/main/cond.html>`__.
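
######################################################################
# As a minimal sketch of the mutation limitation above (the exact error message
# may vary between PyTorch versions), exporting a ``cond`` whose branch mutates
# its operand is expected to fail:

def cond_mutation_bad(x):
    def true_fn(x):
        x.add_(1)  # branch functions must not mutate their operands
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
    return cond(x.sum() > 0, true_fn, false_fn, [x])

try:
    export(cond_mutation_bad, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()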

######################################################################
# ..
#     [NOTE] map is not documented at the moment
#     We can also use ``map``, which applies a function across the first dimension
#     of the first tensor argument.
#
#     from functorch.experimental.control_flow import map
#
#     def map_example(xs):
#         def map_fn(x, const):
#             def true_fn(x):
#                 return x + const
#             def false_fn(x):
#                 return x - const
#             return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
#         return control_flow.map(map_fn, xs, torch.tensor([2.0]))
#
#     exported_map_example = export(map_example, (torch.randn(4, 3),))
#     inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
#     print(exported_map_example(inp))

######################################################################
# Constraints/Dynamic Shapes
# --------------------------
#
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export.export()`` call.
# If we try to run the ``ExportedProgram`` in the example below with a tensor
# with a different shape, we get an error:

class MyModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod2 = MyModule2()
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))

try:
    exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()

######################################################################
# We can relax this constraint using the ``dynamic_shapes`` argument of
# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
# which dimensions of the input tensors are dynamic.
#
# For each tensor argument of the input callable, we can specify a mapping from the dimension
# to a ``torch.export.Dim``.
# A ``torch.export.Dim`` is essentially a named symbolic integer with optional
# minimum and maximum bounds.
#
# Then, the format of ``torch.export.export()``'s ``dynamic_shapes`` argument is a mapping
# from the input callable's tensor argument names to the dimension-to-``Dim`` mappings described above.
# If there is no ``torch.export.Dim`` given to a tensor argument's dimension, then that dimension is
# assumed to be static.
#
# The first argument of ``torch.export.Dim`` is the name for the symbolic integer, used for debugging.
# Then we can specify an optional minimum and maximum bound (inclusive). Below, we show example usage.
#
# In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval [4, 18].

from torch.export import Dim

inp1 = torch.randn(10, 10, 2)

def dynamic_shapes_example1(x):
    x = x[:, 2:]
    return torch.relu(x)

inp1_dim0 = Dim("inp1_dim0")
inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
dynamic_shapes1 = {
    "x": {0: inp1_dim0, 1: inp1_dim1},
}

exported_dynamic_shapes_example1 = export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1)

print(exported_dynamic_shapes_example1(torch.randn(5, 5, 2)))

try:
    exported_dynamic_shapes_example1(torch.randn(8, 1, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1(torch.randn(8, 20, 2))
except Exception:
    tb.print_exc()

try:
    exported_dynamic_shapes_example1(torch.randn(8, 8, 3))
except Exception:
    tb.print_exc()

######################################################################
# Note that if our example inputs to ``torch.export`` do not satisfy the constraints
# given by ``dynamic_shapes``, then we get an error.

inp1_dim1_bad = Dim("inp1_dim1_bad", min=11, max=18)
dynamic_shapes1_bad = {
    "x": {0: inp1_dim0, 1: inp1_dim1_bad},
}

try:
    export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1_bad)
except Exception:
    tb.print_exc()

######################################################################
# We can enforce equalities between dimensions of different tensors
# by using the same ``torch.export.Dim`` object, for example, in matrix multiplication:

inp2 = torch.randn(4, 8)
inp3 = torch.randn(8, 2)

def dynamic_shapes_example2(x, y):
    return x @ y

inp2_dim0 = Dim("inp2_dim0")
inner_dim = Dim("inner_dim")
inp3_dim1 = Dim("inp3_dim1")

dynamic_shapes2 = {
    "x": {0: inp2_dim0, 1: inner_dim},
    "y": {0: inner_dim, 1: inp3_dim1},
}

exported_dynamic_shapes_example2 = export(dynamic_shapes_example2, (inp2, inp3), dynamic_shapes=dynamic_shapes2)

print(exported_dynamic_shapes_example2(torch.randn(2, 16), torch.randn(16, 4)))

try:
    exported_dynamic_shapes_example2(torch.randn(4, 8), torch.randn(4, 2))
except Exception:
    tb.print_exc()

######################################################################
# We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes`` constraints
# are necessary. We can do this by relaxing all constraints (recall that if we
# do not provide constraints for a dimension, the default behavior is to constrain
# to the exact shape value of the example input) and letting ``torch.export``
# error out.

inp4 = torch.randn(8, 16)
inp5 = torch.randn(16, 32)

def dynamic_shapes_example3(x, y):
    if x.shape[0] <= 16:
        return x @ y[:, :16]
    return y

dynamic_shapes3 = {
    "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
    "y": {i: Dim(f"inp5_dim{i}") for i in range(inp5.dim())},
}

try:
    export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3)
except Exception:
    tb.print_exc()

######################################################################
# We can see that the error message gives us suggested fixes to our
# dynamic shape constraints. Let us follow those suggestions (exact
# suggestions may differ slightly):

def suggested_fixes():
    inp4_dim1 = Dim('shared_dim')
    # suggested fixes below
    inp4_dim0 = Dim('inp4_dim0', max=16)
    inp5_dim1 = Dim('inp5_dim1', min=17)
    inp5_dim0 = inp4_dim1
    # end of suggested fixes
    return {
        "x": {0: inp4_dim0, 1: inp4_dim1},
        "y": {0: inp5_dim0, 1: inp5_dim1},
    }

dynamic_shapes3_fixed = suggested_fixes()
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
print(exported_dynamic_shapes_example3(torch.randn(4, 32), torch.randn(32, 64)))

######################################################################
# Note that in the example above, because we constrained the value of ``x.shape[0]`` in
# ``dynamic_shapes_example3``, the exported program is sound even though there is a
# raw ``if`` statement.
#
# If you want to see why ``torch.export`` generated these constraints, you can
# re-run the script with the environment variable ``TORCH_LOGS=dynamic,dynamo``,
# or use ``torch._logging.set_logs``.

import logging
torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)

# reset to previous values
torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)

######################################################################
# We can view an ``ExportedProgram``'s constraints using the ``range_constraints`` and
# ``equality_constraints`` attributes. The logging above reveals what the symbols ``s0, s1, ...``
# represent.

print(exported_dynamic_shapes_example3.range_constraints)
print(exported_dynamic_shapes_example3.equality_constraints)

######################################################################
# Custom Ops
# ----------
#
# ``torch.export`` can export PyTorch programs with custom operators.
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
#
# - If you're writing custom ops purely in Python, use ``torch.library.custom_op``.

import torch.library
import numpy as np

@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: torch.Tensor) -> torch.Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

######################################################################
# - You will need to provide an abstract implementation so that PT2 can trace through it.

@torch.library.register_fake("mylib::sin")
def _(x):
    return torch.empty_like(x)

# - Sometimes, the custom op you are exporting has data-dependent output, meaning
#   we can't determine the shape of the output at compile time. In this case, you can
#   do the following:
@torch.library.custom_op("mylib::nonzero", mutates_args=())
def nonzero(x: torch.Tensor) -> torch.Tensor:
    x_np = x.cpu().numpy()
    res = np.stack(np.nonzero(x_np), axis=1)
    return torch.tensor(res, device=x.device)

@torch.library.register_fake("mylib::nonzero")
def _(x):
    # The number of nonzero elements is data-dependent.
    # Since we cannot peek at the data in an abstract implementation,
    # we use the ``ctx`` object to construct a new ``symint`` that
    # represents the data-dependent size.
    ctx = torch.library.get_ctx()
    nnz = ctx.new_dynamic_size()
    shape = [nnz, x.dim()]
    result = x.new_empty(shape, dtype=torch.int64)
    return result

######################################################################
# - Call the custom op from the code you want to export using ``torch.ops``

def custom_op_example(x):
    x = torch.sin(x)
    x = torch.ops.mylib.sin(x)
    x = torch.cos(x)
    y = torch.ops.mylib.nonzero(x)
    return x + y.sum()

######################################################################
# - Export the code as before

exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
exported_custom_op_example.graph_module.print_readable()
print(exported_custom_op_example(torch.randn(3, 3)))

######################################################################
# Note in the above outputs that the custom op is included in the exported graph.
# And when we call the exported graph as a function, the original custom op is called,
# as evidenced by the ``print`` call.
#
# If you have a custom operator implemented in C++, please refer to
# `this document <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz>`__
# to make it compatible with ``torch.export``.

######################################################################
# Decompositions
# --------------
#
# By default, the graph produced by ``torch.export`` contains only functional ATen
# operators. This functional ATen operator set (or "opset") contains around 2000
# operators, all of which are functional, that is, they do not
# mutate or alias inputs. You can find a list of all ATen operators
# `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
# and you can inspect whether an operator is functional by checking
# ``op._schema.is_mutable``, for example:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

######################################################################
# By default, the environment in which you want to run the exported graph
# should support all ~2000 of these operators.
# However, you can use the following API on the exported program
# if your specific environment is only able to support a subset of
# the ~2000 operators.
#
# .. code:: python
#
#     def run_decompositions(
#         self: ExportedProgram,
#         decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
#     ) -> ExportedProgram
#
# ``run_decompositions`` takes in a decomposition table, which is a mapping from
# operators to functions specifying how to reduce, or decompose, those operators
# into an equivalent sequence of other ATen operators.
#
# The default decomposition table for ``run_decompositions`` is the
# `Core ATen decomposition table <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L252>`__,
# which will decompose all ATen operators to the
# `Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
# which consists of only ~180 operators.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

core_ir_ep = ep.run_decompositions()
print(core_ir_ep.graph)

######################################################################
# Notice that after running ``run_decompositions``, the
# ``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
# Opset, has been replaced with ``torch.ops.aten.permute.default``, which is part
# of the Core ATen Opset.
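#
# As a quick sanity check (a minimal sketch using the FX graph node API; the exact
# operators may differ slightly across PyTorch versions), we can confirm whether
# ``torch.ops.aten.t.default`` still appears among the decomposed graph's nodes:

# expected to print True (present before decomposition) and False (gone afterwards)
print(any(node.target == torch.ops.aten.t.default for node in ep.graph.nodes))
print(any(node.target == torch.ops.aten.t.default for node in core_ir_ep.graph.nodes))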

######################################################################
# Most ATen operators already have decompositions, which are located
# `here <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/decompositions.py>`__.
# If you would like to use some of these existing decomposition functions,
# you can pass in a list of operators you would like to decompose to the
# `get_decompositions <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L191>`__
# function, which will return a decomposition table using existing
# decomposition implementations.

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 3),))
print(ep.graph)

from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
core_ir_ep = ep.run_decompositions(decomp_table)
print(core_ir_ep.graph)

######################################################################
# If there is no existing decomposition function for an ATen operator that you would
# like to decompose, feel free to send a pull request to PyTorch
# implementing the decomposition!

######################################################################
# ExportDB
# --------
#
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
# there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to
# rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting
# if-statements using ``cond``.
#
# `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ is the standard reference that documents
# supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list of program samples, each
# of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``.
# Examples are also tagged by category so that they can be more easily searched.
#
# For example, let's use ExportDB to get a better understanding of how the predicate works in the ``cond`` operator.
# We can look at the example called ``cond_predicate``, which has a ``torch.cond`` tag. The example code looks like:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - torch.Tensor with a single element
    - boolean expression
    NOTE: If the ``pred`` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

######################################################################
# More generally, ExportDB can be used as a reference when one of the following occurs:
#
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features
#    and you want to know if ``torch.export`` covers that feature.
# 2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it.
#
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

######################################################################
# Conclusion
# ----------
#
# We introduced ``torch.export``, the new PyTorch 2.X way to export single computation
# graphs from PyTorch programs. In particular, we demonstrated several code modifications
# and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.