Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pytorch
GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/python_custom_ops_functional.py
6463 views
1
# -*- coding: utf-8 -*-
2
3
"""
4
.. _python-custom-ops-functional:
5
6
Functional Python Custom Operators
7
==================================
8
9
Use this path when the operator mutates no Tensor inputs and returns fresh
10
Tensor outputs.
11
12
If the operator must work with ``torch.compile`` or ``torch.export``,
13
:ref:`register a fake kernel <python-custom-ops-functional-register-fake>`.
14
The fake kernel describes output metadata without running the real kernel.
15
16
Before writing the operator, read the required schema and mutation/aliasing
17
contract rules in :ref:`python-custom-ops-schema-contract`.
18
19
Checklist:
20
21
* use ``mutates_args=()``;
22
* return tensors that do not alias any input;
23
* register a fake kernel for ``torch.compile`` and ``torch.export``;
24
* validate the operator with ``torch.library.opcheck``.
25
"""
26
27
######################################################################
28
# Example: wrapping NumPy sin into a custom operator
29
# --------------------------------------------------
30
# Suppose we want to use NumPy's ``sin`` operation. The wrapper below,
# ``numpy_sin_impl``, is an ordinary Python function from PyTorch's point of
# view: it converts the input Tensor to a NumPy array, calls NumPy, and returns
# a fresh Tensor.
33
34
import numpy as np
35
import torch
36
from torch import Tensor
37
38
39
def numpy_sin_impl(x: Tensor) -> Tensor:
    """Compute ``sin(x)`` with NumPy, writing into a freshly allocated Tensor.

    ``x`` is detached before conversion because ``Tensor.numpy()`` refuses
    Tensors that require grad. The output is allocated with
    ``torch.empty_like``, so it gets ``x``'s shape, dtype, and device
    (and, under the default ``preserve_format``, its memory layout) but
    never shares storage with ``x``.
    """
    out = torch.empty_like(x)
    # ``out.numpy()`` is a view over ``out``'s memory, so NumPy's ``out=``
    # argument fills the Tensor directly.
    np.sin(x.detach().numpy(), out=out.numpy())
    return out


# Sanity check against PyTorch's own ``sin``; ``x`` is reused further below.
x = torch.randn(5)
torch.testing.assert_close(actual=numpy_sin_impl(x), expected=x.sin())
47
48
# This small example focuses on the custom-operator mechanics. More complex
49
# Python or third-party library calls may not be handled effectively
50
# out-of-the-box by ``torch.compile``: ``torch.compile`` may induce a
51
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
52
# on functions it is unable to handle, and graph breaks are bad for performance.
53
# A custom operator gives PyTorch an explicit boundary for such code.
54
#
55
# To make ``numpy_sin_impl`` available as a custom operator that works with
56
# ``torch.compile`` and ``torch.export``, we need to do two things:
57
#
58
# 1. wrap the function into a PyTorch custom operator.
59
# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
#    Given ``FakeTensor`` inputs (dummy Tensors that don't have storage),
#    this function must return dummy Tensors whose metadata
#    (shape/strides/``dtype``/device, and storage offset when relevant)
#    matches what the real kernel would produce.
63
64
65
@torch.library.custom_op(
    "mylib_functional::numpy_sin",
    mutates_args=(),
    device_types="cpu",
)
def numpy_sin(x: Tensor) -> Tensor:
    """Custom-operator wrapper around ``numpy_sin_impl``.

    ``mutates_args=()`` declares that the operator writes to none of its
    inputs, and the returned Tensor is the fresh allocation produced by
    ``numpy_sin_impl`` (never ``x`` or a view of it). ``device_types="cpu"``
    registers this implementation for CPU only, which matches the
    ``Tensor.numpy()`` calls inside the helper.
    """
    # Delegate instead of duplicating the NumPy call, so the operator runs
    # exactly the function that was sanity-checked against ``torch.sin``.
    return numpy_sin_impl(x)
74
75
76
######################################################################
77
# .. _python-custom-ops-functional-register-fake:
78
#
79
# Use ``register_fake`` to add a ``FakeTensor`` kernel for the operator.
80
# ``numpy_sin`` returns one Tensor with the same shape, strides, dtype, device,
81
# and storage offset as ``torch.empty_like(x)``, so the fake kernel can return
82
# ``torch.empty_like(x)``. In general, the fake kernel must match all output metadata,
83
# including storage offset when relevant.
84
85
86
@numpy_sin.register_fake
def _(x: Tensor) -> Tensor:
    """FakeTensor kernel for ``mylib_functional::numpy_sin``.

    Allocates an uninitialized Tensor laid out like ``x`` (no data is
    computed). The real kernel also allocates its output with
    ``torch.empty_like(x)``, so the shape, strides, dtype, device, and
    storage offset seen during tracing match the real output.
    """
    # ``preserve_format`` is ``empty_like``'s default; spelled out here
    # because preserving ``x``'s layout is what keeps the strides in sync.
    return torch.empty_like(x, memory_format=torch.preserve_format)
89
90
91
######################################################################
92
# After this, ``numpy_sin`` can be used under ``torch.compile``:
93
94
95
@torch.compile(fullgraph=True)
def f(x):
    """Call the ``numpy_sin`` custom operator from compiled code.

    With ``fullgraph=True``, ``torch.compile`` raises an error on a graph
    break instead of splitting the function, so a successful call shows the
    operator (together with its fake kernel) is captured in one graph.
    """
    return numpy_sin(x)


# The compiled result must agree with PyTorch's eager ``sin``.
result = f(x)
torch.testing.assert_close(actual=result, expected=x.sin())
102
103
######################################################################
104
# A PIL image transform, Python binding to a C++ extension, or another
105
# third-party library call follows the same pattern. If it returns tensors,
106
# write the fake kernel to match the real output metadata exactly: shape,
107
# strides, dtype, device, layout, and storage offset when relevant.
108
109
######################################################################
110
# Example: fake kernels must match strides
111
# ----------------------------------------
112
# The fake kernel must match the real output strides, not only the shape. This
113
# operator returns a fresh Tensor with the same shape as ``x`` but different
114
# strides.
115
116
117
def numpy_sin_strided_impl(x: Tensor) -> Tensor:
118
result = torch.empty_strided(
119
x.shape,
120
tuple(reversed(x.stride())),
121
dtype=x.dtype,
122
device=x.device,
123
)
124
np.sin(x.detach().numpy(), out=result.numpy())
125
return result
126
127
128
@torch.library.custom_op(
129
"mylib_functional::numpy_sin_strided_bad",
130
mutates_args=(),
131
device_types="cpu",
132
)
133
def numpy_sin_strided_bad(x: Tensor) -> Tensor:
134
return numpy_sin_strided_impl(x)
135
136
137
@numpy_sin_strided_bad.register_fake
138
def _(x):
139
return torch.empty_like(x)
140
141
142
try:
143
torch.library.opcheck(numpy_sin_strided_bad, (torch.randn(2, 3),))
144
except Exception as exc:
145
print(f"opcheck caught incorrect fake kernel metadata: {type(exc).__name__}")
146
else:
147
torch_version = tuple(
148
int(part) for part in torch.__version__.split("+")[0].split(".")[:2]
149
)
150
if torch_version >= (2, 13):
151
raise AssertionError("Expected opcheck to fail")
152
print("PyTorch versions before 2.13 may not catch this metadata mismatch")
153
154
155
@torch.library.custom_op(
156
"mylib_functional::numpy_sin_strided",
157
mutates_args=(),
158
device_types="cpu",
159
)
160
def numpy_sin_strided(x: Tensor) -> Tensor:
161
return numpy_sin_strided_impl(x)
162
163
164
@numpy_sin_strided.register_fake
165
def _(x):
166
return torch.empty_strided(
167
x.shape,
168
tuple(reversed(x.stride())),
169
dtype=x.dtype,
170
device=x.device,
171
)
172
173
174
torch.library.opcheck(numpy_sin_strided, (torch.randn(2, 3),))
175
176
######################################################################
177
# Testing Python custom operators
178
# -------------------------------
179
# Use ``torch.library.opcheck`` to test that the custom operator was registered
180
# correctly. This does not test numerical correctness; write separate tests for
181
# that.
182
#
183
# To use ``opcheck``, pass it a set of example inputs to test against. If your
184
# operator supports training, then the examples should include Tensors that
185
# require grad. If your operator supports multiple devices, then the examples
186
# should include Tensors from each device.
187
188
189
# Each entry is an ``args`` tuple for ``opcheck``. The cases cover a 1-D
# input, an empty input, a non-default dtype, a non-contiguous (transposed)
# input, and a view with a nonzero storage offset.
examples = [
    (torch.randn(5),),
    (torch.randn(0, 3),),
    (torch.randn(2, 3, dtype=torch.double),),
    (torch.randn(2, 3).t(),),
    (torch.randn(8)[1:],),
]

for example in examples:
    torch.library.opcheck(numpy_sin, args=example)
199
200
######################################################################
201
# To add autograd, ``torch.vmap``, or other subsystem support, continue to
202
# :ref:`python-custom-ops-registrations`.
203
204