# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-registrations:

Adding Training and Other Registrations to Python Custom Operators
==================================================================

Start here after a base operator passes ``torch.library.opcheck``:

* :ref:`python-custom-ops-functional`
* :ref:`python-custom-ops-mutable`

Registrations do not change the base contract. After adding one, rerun
``torch.library.opcheck`` on representative inputs for that subsystem.
"""

######################################################################
# Adding training support for NumPy sin
# -------------------------------------
# Use ``torch.library.register_autograd`` to add training support for an
# operator. Prefer this over using ``torch.autograd.Function`` directly; some
# combinations of ``autograd.Function`` with PyTorch operator registration
# APIs can lead to (and have led to) silent incorrectness under
# ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``. If you train with a ``custom_op`` that
# has no autograd registration, PyTorch raises an error.
#
# This page uses the same ``numpy.sin`` operation as the functional and mutable
# pages, so the only new concept is the autograd registration.

import numpy as np
import torch
from torch import Tensor


# Real CPU kernel: computes sin with NumPy, writing into a fresh output tensor.
@torch.library.custom_op(
    "mylib_training::numpy_sin",
    mutates_args=(),
    device_types="cpu",
)
def numpy_sin(x: Tensor) -> Tensor:
    result = torch.empty_like(x)
    np.sin(x.detach().numpy(), out=result.numpy())
    return result


# Fake kernel: describes the output's metadata without reading x's data.
@numpy_sin.register_fake
def _(x):
    return torch.empty_like(x)


######################################################################
# The fake kernel must describe the same output metadata as the real kernel,
# including shape, strides, dtype, device, layout, and storage offset when
# relevant. Here the real kernel allocates its output with
# ``torch.empty_like(x)``, so the fake kernel returns ``torch.empty_like(x)``
# as well.
#
# The derivative of ``sin(x)`` is ``cos(x)``, so the backward formula
# multiplies the incoming gradient by ``cos(x)``. Write the backward formula in
# terms of operations PyTorch understands: PyTorch operators or other custom
# operators. Do not call non-traceable Python or NumPy code directly from the
# backward formula.


def numpy_sin_setup_context(ctx, inputs, output):
    # Save the input: the backward formula needs x to compute cos(x).
    (x,) = inputs
    ctx.save_for_backward(x)


def numpy_sin_backward(ctx, grad_output):
    # d/dx sin(x) = cos(x), written with a PyTorch operator so it stays traceable.
    (x,) = ctx.saved_tensors
    return grad_output * x.cos()


######################################################################
# Register the backward formula and the context setup function:


numpy_sin.register_autograd(
    numpy_sin_backward,
    setup_context=numpy_sin_setup_context,
)


# Quick check: the gradient of sum(sin(x)) with respect to x is cos(x).
x = torch.randn(5, requires_grad=True)
y = numpy_sin(x)
y.sum().backward()
torch.testing.assert_close(x.grad, x.detach().cos())

######################################################################
# Testing autograd registration
# -----------------------------
# ``opcheck`` verifies that autograd was registered in a supported way, but it
# does not prove that the gradient formula is mathematically correct. Test the
# formula separately with numerical checks, either hand-written ones or
# ``torch.autograd.gradcheck``.


# gradcheck compares the registered backward against finite differences; its
# default tolerances are designed for double-precision inputs.
gradcheck_input = torch.randn(3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(numpy_sin, (gradcheck_input,))

# Representative inputs: plain, empty, requires_grad, double precision,
# non-contiguous (transposed), and nonzero storage offset (sliced).
examples = [
    (torch.randn(5),),
    (torch.randn(0, 3),),
    (torch.randn(4, requires_grad=True),),
    (torch.randn(2, dtype=torch.double, requires_grad=True),),
    (torch.randn(2, 3).t(),),
    (torch.randn(8)[1:],),
]

for example in examples:
    torch.library.opcheck(numpy_sin, example)


######################################################################
# Other registrations
# -------------------
# Add these only when users need them.
#
# * **Multiple device kernels:** pass ``device_types="cpu"`` or
#   ``device_types="cuda"`` if the implementation only works on one device.
#   Register device-specific kernels (for example, with ``register_kernel``)
#   when devices need different code.
# * **Batching:** register a vmap rule with ``torch.library.register_vmap``
#   when ``torch.vmap`` over the operator should do something different from a
#   Python loop over the batch dimension.
# * **Tensor subclasses or modes:** use ``torch.library.register_torch_dispatch``
#   when a Tensor subclass or ``TorchDispatchMode`` needs special behavior.
# * **Autocast:** for C++/CUDA operators that should participate in autocast,
#   add an autocast registration as described in the C++ custom operator guide.