Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pytorch
GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/python_custom_ops_mutable.py
6463 views
1
# -*- coding: utf-8 -*-
2
3
"""
4
.. _python-custom-ops-mutable:

Mutable Python Custom Operators
===============================

:ref:`Functional custom operators <python-custom-ops-functional>` showed
``numpy.sin`` as an operator that returns a fresh Tensor. This page shows the
mutable version: a kernel that writes ``sin(x)`` into an existing output Tensor.
Mutable operators have a different contract from functional operators.

Before writing the operator, read the required schema and mutation/aliasing
contract rules in :ref:`python-custom-ops-schema-contract`.

Checklist:

* choose one mutation pattern and keep it stable;
* list every mutated Tensor argument in ``mutates_args``;
* do not return mutated inputs unless you are using a tagged in-place or
  ``out=`` operator (available starting in PyTorch 2.13);
* validate the operator with ``torch.library.opcheck``.
24
"""
25
26
######################################################################
# Choose one mutation contract
# ----------------------------
# Choose the mutation behavior before adding optional registrations. PyTorch
# needs this contract for functionalization in ``torch.compile``
# and autograd.
#
# If the operator does not mutate any Tensor input, use the functional operator
# path instead.
#
# If the operator mutates the first positional Tensor and returns it, use a
# tagged in-place operator, starting in PyTorch 2.13.
#
# If the operator accepts write-only keyword-only ``out=`` Tensor arguments and
# returns them, use a tagged ``out=`` operator, starting in PyTorch 2.13.
#
# For other mutable operators, list every mutated argument in ``mutates_args``
# and do not return mutated inputs or their aliases.
44
45
import numpy as np
46
import torch
47
from torch import Tensor
48
49
50
######################################################################
# Example: write NumPy sin into an output buffer
# ----------------------------------------------
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# This operator writes ``sin(x)`` into ``out`` and returns ``None``.
58
59
60
@torch.library.custom_op(
61
"mylib_mutable::numpy_sin_out",
62
mutates_args={"out"},
63
device_types="cpu",
64
)
65
def numpy_sin_out(x: Tensor, out: Tensor) -> None:
66
if x.shape != out.shape:
67
raise RuntimeError("x and out must have the same shape")
68
if x.dtype != out.dtype:
69
raise RuntimeError("x and out must have the same dtype")
70
if x.device != out.device:
71
raise RuntimeError("x and out must be on the same device")
72
np.sin(x.detach().numpy(), out=out.numpy())
73
74
75
# Eager-mode check: after the call, ``out`` must hold ``sin(x)``.
x = torch.randn(5)
out = torch.empty_like(x)
numpy_sin_out(x, out)
torch.testing.assert_close(out, x.sin())
79
80
######################################################################
# Because the operator doesn't return anything, there is no need to register a
# ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.
# If a mutable operator also returns a fresh Tensor, register a fake kernel for
# that output.
85
86
87
@torch.compile(fullgraph=True)
def compiled_numpy_sin_out(x):
    """Allocate a buffer like ``x``, fill it via ``numpy_sin_out``, and return it."""
    out = torch.empty_like(x)
    numpy_sin_out(x, out)
    return out


# The compiled wrapper must agree with the reference ``Tensor.sin``.
torch.testing.assert_close(compiled_numpy_sin_out(x), x.sin())
95
96
######################################################################
# PyTorch-style in-place and out= operators
# -----------------------------------------
# Starting in PyTorch 2.13, ``torch.library.custom_op`` supports tagged
# in-place and ``out=`` custom operators.
# Tagged in-place operators return the same Tensor they mutate. Tagged ``out=``
# operators return their keyword-only output buffers in declaration order.
# This example uses ``mylib_mutable::sin_`` for a tagged in-place custom
# operator and ``mylib_mutable::sin_out`` for a tagged ``out=`` custom operator.
105
106
107
# Feature-detect the ``inplace`` and ``out`` tags. ``and`` short-circuits, so
# ``torch.Tag`` is only accessed when the attribute exists.
supports_tagged_mutable_ops = (
    hasattr(torch, "Tag")
    and hasattr(torch.Tag, "inplace")
    and hasattr(torch.Tag, "out")
)
112
113
if supports_tagged_mutable_ops:

    @torch.library.custom_op(
        "mylib_mutable::sin_",
        mutates_args={"x"},
        tags=torch.Tag.inplace,
    )
    def sin_(x: Tensor) -> Tensor:
        """Apply ``sin`` to ``x`` in place and return that same tensor."""
        x.sin_()
        return x


    @torch.library.custom_op(
        "mylib_mutable::sin_out",
        mutates_args={"out"},
        tags=torch.Tag.out,
    )
    def sin_out(x: Tensor, *, out: Tensor) -> Tensor:
        """Write ``sin(x)`` into the keyword-only ``out`` buffer and return it."""
        torch.sin(x, out=out)
        return out


    # Compute the reference value before calling ``sin_``, because the call
    # overwrites ``x_for_inplace``.
    x_for_inplace = torch.randn(3)
    expected = x_for_inplace.sin()
    torch.testing.assert_close(sin_(x_for_inplace), expected)

    # Both the returned tensor and the buffer passed as ``out=`` are checked.
    out_for_sin = torch.empty_like(x)
    torch.testing.assert_close(
        sin_out(x, out=out_for_sin),
        x.sin(),
    )
    torch.testing.assert_close(out_for_sin, x.sin())

    torch.library.opcheck(sin_, (torch.randn(3),))
    torch.library.opcheck(
        sin_out,
        (torch.randn(3),),
        {"out": torch.empty(3)},
    )
else:
    print("Tagged in-place and out= custom operators require PyTorch 2.13 or later.")
154
155
156
######################################################################
# Validate the operator
# ---------------------
# Run ``opcheck`` to confirm that the operator is registered correctly.
# For example, ``opcheck`` would error out if we forgot to add ``out`` to
# ``mutates_args``.
162
163
164
# Sample ``(x, out)`` argument pairs for ``opcheck``: a 1-D case, a
# zero-element case, a double-precision case, and a non-contiguous case
# (transposed input with an output that has the same shape and strides).
examples = [
    (torch.randn(5), torch.empty(5)),
    (torch.randn(0, 3), torch.empty(0, 3)),
    (
        torch.randn(2, 3, dtype=torch.double),
        torch.empty(2, 3, dtype=torch.double),
    ),
    (
        torch.randn(2, 3).t(),
        torch.empty_strided((3, 2), (1, 3)),
    ),
]

for example in examples:
    torch.library.opcheck(numpy_sin_out, example)
179
180
######################################################################
# For autograd, ``torch.vmap``, or other subsystem behavior, continue to
# :ref:`python-custom-ops-registrations`.
183
184