# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-registrations:

Adding Training and Other Registrations to Python Custom Operators
==================================================================

Start here after a base operator passes ``torch.library.opcheck``:

* :ref:`python-custom-ops-functional`
* :ref:`python-custom-ops-mutable`

Registrations do not change the base contract. After adding one, rerun
``torch.library.opcheck`` on representative inputs for that subsystem.
"""

######################################################################
# Adding training support for NumPy sin
# -------------------------------------
# Use ``torch.library.register_autograd`` to add training support for an
# operator. Prefer this over using ``torch.autograd.Function`` directly: some
# combinations of ``autograd.Function`` with PyTorch operator registration
# APIs can lead to (and have led to) silent incorrectness under
# ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``. If you try to train through a
# ``custom_op`` that has no autograd registration, we'll raise an error.
#
# This page uses the same ``numpy.sin`` operation as the functional and mutable
# pages, so the only new concept is the autograd registration.

import numpy as np
import torch
from torch import Tensor


@torch.library.custom_op(
    "mylib_training::numpy_sin",
    mutates_args=(),
    device_types="cpu",
)
def numpy_sin(x: Tensor) -> Tensor:
    result = torch.empty_like(x)
    np.sin(x.detach().numpy(), out=result.numpy())
    return result


@numpy_sin.register_fake
def _(x):
    return torch.empty_like(x)


######################################################################
# The fake kernel must describe the same output metadata as the real kernel,
# including shape, strides, dtype, device, layout, and storage offset when
# relevant.
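
######################################################################
# The rule matters most when the output metadata differs from the input's.
# As a minimal sketch, assume a hypothetical operator
# ``mylib_sketch::numpy_sin_f64`` that always returns a contiguous ``float64``
# tensor. Its fake kernel must then allocate a contiguous ``float64`` tensor
# as well; ``torch.empty_like(x)`` would report ``float32`` for ``float32``
# inputs and copy the input's strides. Passing
# ``test_utils="test_faketensor"`` to ``opcheck`` runs only the check that the
# fake kernel exists and agrees with the real kernel.

import numpy as np
import torch
from torch import Tensor


@torch.library.custom_op(
    "mylib_sketch::numpy_sin_f64",  # hypothetical operator for this sketch
    mutates_args=(),
    device_types="cpu",
)
def numpy_sin_f64(x: Tensor) -> Tensor:
    # Convert to a C-contiguous float64 array so the output is always
    # contiguous float64, regardless of the input's dtype or strides.
    x_np = x.detach().numpy().astype(np.float64, order="C")
    return torch.from_numpy(np.sin(x_np))


@numpy_sin_f64.register_fake
def _(x):
    # Same shape and device as the input, but float64 with contiguous
    # strides, matching what the real kernel returns.
    return x.new_empty(x.shape, dtype=torch.float64)


for sample in (torch.randn(4), torch.randn(2, 3).t(), torch.randn(0, 3)):
    torch.library.opcheck(
        numpy_sin_f64, (sample,), test_utils="test_faketensor"
    )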

######################################################################
# For ``numpy_sin``, the real kernel returns ``torch.empty_like(x)``, so the
# fake kernel does the same.
#
# The derivative of ``sin(x)`` is ``cos(x)``, so the backward formula returns
# ``grad_output * cos(x)``. Write the backward formula in terms of PyTorch
# operations or other custom operators. Do not call NumPy or other
# non-traceable Python code directly in the backward formula; if the backward
# needs such code, wrap it in its own custom operator.


def numpy_sin_setup_context(ctx, inputs, output):
    # Save the input; the backward formula needs it to compute cos(x).
    (x,) = inputs
    ctx.save_for_backward(x)


def numpy_sin_backward(ctx, grad_output):
    (x,) = ctx.saved_tensors
    # Chain rule: d/dx sin(x) = cos(x).
    return grad_output * x.cos()


######################################################################
# Register the backward formula and the context setup function:


numpy_sin.register_autograd(
    numpy_sin_backward,
    setup_context=numpy_sin_setup_context,
)


# Check the computed gradient against the analytic derivative cos(x).
x = torch.randn(5, requires_grad=True)
y = numpy_sin(x)
y.sum().backward()
torch.testing.assert_close(x.grad, x.detach().cos())

######################################################################
# Testing autograd registration
# -----------------------------
# ``opcheck`` verifies that autograd was registered in a supported way, but it
# does not prove that the gradient formula is mathematically correct.
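
######################################################################
# As a minimal sketch of that gap, the hypothetical operator
# ``mylib_sketch::sin_wrong_grad`` computes ``sin`` but registers a
# deliberately wrong backward (``sin`` instead of ``cos``). The registration
# itself is valid, so ``opcheck`` restricted to
# ``test_utils="test_autograd_registration"`` accepts it, while
# ``torch.autograd.gradcheck`` compares the backward against finite
# differences and reports the mismatch. With ``raise_exception=False``,
# ``gradcheck`` returns ``False`` instead of raising.

import numpy as np
import torch
from torch import Tensor


@torch.library.custom_op(
    "mylib_sketch::sin_wrong_grad",  # hypothetical operator for this sketch
    mutates_args=(),
    device_types="cpu",
)
def sin_wrong_grad(x: Tensor) -> Tensor:
    result = torch.empty_like(x)
    np.sin(x.detach().numpy(), out=result.numpy())
    return result


@sin_wrong_grad.register_fake
def _(x):
    return torch.empty_like(x)


def _wrong_setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)


def _wrong_backward(ctx, grad_output):
    (x,) = ctx.saved_tensors
    # Deliberately wrong: d/dx sin(x) is cos(x), not sin(x).
    return grad_output * x.sin()


sin_wrong_grad.register_autograd(
    _wrong_backward, setup_context=_wrong_setup_context
)

wrong_input = torch.randn(3, dtype=torch.double, requires_grad=True)

# Only inspects how autograd was registered, so the wrong formula passes.
torch.library.opcheck(
    sin_wrong_grad, (wrong_input,), test_utils="test_autograd_registration"
)

# Compares against numerical derivatives, so the wrong formula is caught.
wrong_grad_passes = torch.autograd.gradcheck(
    sin_wrong_grad, (wrong_input,), raise_exception=False
)
print(wrong_grad_passes)  # False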

######################################################################
# To check the gradient formula itself, use separate numerical tests, either
# manual ones or ``torch.autograd.gradcheck``. ``gradcheck`` compares the
# registered backward against finite differences, so give it double-precision
# inputs.


gradcheck_input = torch.randn(3, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(numpy_sin, (gradcheck_input,))

# Representative inputs: plain, empty, requiring grad, double precision,
# non-contiguous (transposed), and a slice with a nonzero storage offset.
examples = [
    (torch.randn(5),),
    (torch.randn(0, 3),),
    (torch.randn(4, requires_grad=True),),
    (torch.randn(2, dtype=torch.double, requires_grad=True),),
    (torch.randn(2, 3).t(),),
    (torch.randn(8)[1:],),
]

for example in examples:
    torch.library.opcheck(numpy_sin, example)


######################################################################
# Other registrations
# -------------------
# Add these only when users need them.
#
# * **Multiple device kernels:** pass ``device_types="cpu"`` or
#   ``device_types="cuda"`` if the implementation only works on one device.
#   When devices need different code, register device-specific kernels with
#   ``torch.library.register_kernel``.
# * **vmap:** register a vmap rule with ``torch.library.register_vmap`` when
#   batching over the operator with ``torch.vmap`` should do something
#   different from a Python loop over the batch dimension.
# * **Tensor subclasses or modes:** use ``torch.library.register_torch_dispatch``
#   when a Tensor subclass or ``TorchDispatchMode`` needs special behavior.
# * **Autocast:** for C++/CUDA operators that should participate in autocast,
#   add an autocast registration as described in the C++ custom operator guide.
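
######################################################################
# A minimal sketch of a vmap registration, assuming a hypothetical elementwise
# operator ``mylib_sketch::numpy_sin_vmap``. Because ``sin`` is elementwise,
# the rule can call the operator once on the whole batched tensor, rather than
# once per batch element, and report the batch dimension unchanged. The rule
# receives ``info``, ``in_dims`` (the batch dimension of each argument, or
# ``None`` if that argument is not batched), and the arguments with their
# batch dimensions still in place; it returns the output together with the
# output's batch dimension.

import numpy as np
import torch
from torch import Tensor


@torch.library.custom_op(
    "mylib_sketch::numpy_sin_vmap",  # hypothetical operator for this sketch
    mutates_args=(),
    device_types="cpu",
)
def numpy_sin_vmap(x: Tensor) -> Tensor:
    result = torch.empty_like(x)
    np.sin(x.detach().numpy(), out=result.numpy())
    return result


@numpy_sin_vmap.register_fake
def _(x):
    return torch.empty_like(x)


def numpy_sin_vmap_rule(info, in_dims, x):
    # sin is elementwise, so applying it to the whole tensor (batch dimension
    # included) yields the batched result with the batch dimension unchanged.
    (x_bdim,) = in_dims
    return numpy_sin_vmap(x), x_bdim


torch.library.register_vmap(numpy_sin_vmap, numpy_sin_vmap_rule)

batched = torch.randn(4, 3)
vmapped = torch.vmap(numpy_sin_vmap)(batched)
torch.testing.assert_close(vmapped, batched.sin())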