# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-tutorial:

Python Custom Operators
=======================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in Python with PyTorch
       * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (e.g.,
``torch.add``, ``torch.sum``, etc.). However, you might wish to use a new customized
operator with PyTorch, perhaps written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, preventing ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function.

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example ``torch.compile``, training support) should
just work.
"""
######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ---------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)

######################################################################

cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)

######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)

######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
#    Given some ``FakeTensor`` inputs (dummy Tensors that don't have storage),
#    this function should return dummy Tensors of your choice with the correct
#    Tensor metadata (shape/strides/``dtype``/device).


from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator.
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    return pic.new_empty(channels, y1 - y0, x1 - x0)
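
######################################################################
# As a quick sanity check (a sketch, not part of the original tutorial, using
# the internal ``FakeTensorMode`` API), the ``FakeTensor`` kernel can be
# exercised directly: the operator returns a storage-less tensor with the
# expected shape without ever calling into PIL.

from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake_pic = torch.empty(3, 64, 64)
    fake_out = crop(fake_pic, (10, 10, 50, 50))
print(fake_out.shape)  # torch.Size([3, 40, 40])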

######################################################################
# After this, ``crop`` now works without graph breaks:

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)

######################################################################

display(cropped_img)

######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# have led to) silent incorrectness when composed with ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``.
# If you end up training with a ``custom_op`` that doesn't have an autograd
# registration, PyTorch will raise an error.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)
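
######################################################################
# As a small illustration (a sketch, not part of the original tutorial) of what
# ``paste`` computes: pasting a patch of ones onto a canvas of zeros at
# coordinate ``(2, 2)`` fills exactly that region, which is the shape the
# gradient of ``crop`` will take.

canvas = torch.zeros(3, 8, 8)
patch = torch.ones(3, 4, 4)
pasted = paste(canvas, patch, (2, 2))
assert pasted[:, 2:6, 2:6].eq(1).all()   # the pasted region is ones
assert pasted.sum() == patch.sum()       # everything else stayed zero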

######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped ``paste`` into a custom operator instead of directly using
# PIL's paste.

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)

######################################################################
# This is the correct gradient, with 1s (white) in the cropped region and 0s
# (black) in the unused region.

######################################################################
# Testing Python Custom Operators
# -------------------------------
# Use ``torch.library.opcheck`` to test that the custom operator was registered
# correctly. This does not test that the gradients are mathematically correct;
# please write separate tests for that (either manual ones, as sketched below,
# or ``torch.autograd.gradcheck``).
#
# To use ``opcheck``, pass it a set of example inputs to test against. If your
# operator supports training, then the examples should include Tensors that
# require grad. If your operator supports multiple devices, then the examples
# should include Tensors from each device.

examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)
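
######################################################################
# A minimal sketch of the kind of manual gradient test mentioned above (not part
# of the original tutorial): the gradient of ``crop`` should be a mask of ones
# over the cropped region and zeros elsewhere, so we can assert that directly.

pic = torch.rand(3, 32, 32, requires_grad=True)
out = crop(pic, (4, 4, 12, 12))
out.sum().backward()
expected = torch.zeros_like(pic)
expected[:, 4:12, 4:12] = 1.0
assert torch.allclose(pic.grad, expected)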
######################################################################
# Mutable Python Custom Operators
# -------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
# custom operator.

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

######################################################################
# Because the operator doesn't return anything, there is no need to register
# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())

######################################################################
# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.

example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)

######################################################################
# Conclusion
# ----------
# In this tutorial, we learned how to use ``torch.library.custom_op`` to
# create a custom operator in Python that works with PyTorch subsystems
# such as ``torch.compile`` and autograd.
#
# This tutorial provides a basic introduction to custom operators.
# For more detailed information, see:
#
# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_
#