Path: blob/main/advanced_source/python_custom_ops_functional.py
6463 views
# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-functional:

Functional Python Custom Operators
==================================

Use this path when the operator mutates no Tensor inputs and returns fresh
Tensor outputs.

If the operator must work with ``torch.compile`` or ``torch.export``,
:ref:`register a fake kernel <python-custom-ops-functional-register-fake>`.
The fake kernel describes output metadata without running the real kernel.

Before writing the operator, read the required schema and mutation/aliasing
contract rules in :ref:`python-custom-ops-schema-contract`.

Checklist:

* use ``mutates_args=()``;
* return tensors that do not alias any input;
* register a fake kernel for ``torch.compile`` and ``torch.export``;
* validate the operator with ``torch.library.opcheck``.
"""

######################################################################
# Example: wrapping NumPy sin into a custom operator
# --------------------------------------------------
# Let's say that we are using NumPy's ``sin`` operation. This is an ordinary
# Python function from PyTorch's point of view: it converts the Tensor to a
# NumPy array, calls NumPy, and returns a fresh Tensor.

import numpy as np
import torch
from torch import Tensor


def numpy_sin_impl(x: Tensor) -> Tensor:
    """Return the elementwise sine of ``x``, computed by NumPy.

    Args:
        x: Input tensor. It is only read, never written.

    Returns:
        A newly allocated tensor created with ``torch.empty_like(x)`` and
        filled with ``sin(x)``.
    """
    result = torch.empty_like(x)
    # ``Tensor.numpy()`` returns an array that shares memory with the tensor,
    # so writing through ``out=`` fills ``result`` in place. ``detach()`` is
    # needed because ``numpy()`` refuses tensors that require grad.
    np.sin(x.detach().numpy(), out=result.numpy())
    return result


# Sanity check against PyTorch's own ``sin``.
x = torch.randn(5)
torch.testing.assert_close(numpy_sin_impl(x), x.sin())

######################################################################
# This small example focuses on the custom-operator mechanics.
More complex48# Python or third-party library calls may not be handled effectively49# out-of-the-box by ``torch.compile``: ``torch.compile`` may induce a50# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_51# on functions it is unable to handle, and graph breaks are bad for performance.52# A custom operator gives PyTorch an explicit boundary for such code.53#54# To make ``numpy_sin_impl`` available as a custom operator that works with55# ``torch.compile`` and ``torch.export``, we need to do two things:56#57# 1. wrap the function into a PyTorch custom operator.58# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.59# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage),60# this function should return dummy Tensors of your choice with the correct61# Tensor metadata (shape/strides/``dtype``/device).626364@torch.library.custom_op(65"mylib_functional::numpy_sin",66mutates_args=(),67device_types="cpu",68)69def numpy_sin(x: Tensor) -> Tensor:70result = torch.empty_like(x)71np.sin(x.detach().numpy(), out=result.numpy())72return result737475######################################################################76# .. _python-custom-ops-functional-register-fake:77#78# Use ``register_fake`` to add a ``FakeTensor`` kernel for the operator.79# ``numpy_sin`` returns one Tensor with the same shape, strides, dtype, device,80# and storage offset as ``torch.empty_like(x)``, so the fake kernel can return81# ``empty_like(x)``. 
# In general, the fake kernel must match all output metadata,
# including storage offset when relevant.


@numpy_sin.register_fake
def _(x: Tensor) -> Tensor:
    """Fake kernel for ``numpy_sin``: describe the output without computing it."""
    # The real kernel allocates its output with ``torch.empty_like(x)``, so
    # the same call reproduces the output metadata here.
    return torch.empty_like(x)


######################################################################
# After this, ``numpy_sin`` can be used under ``torch.compile``:


@torch.compile(fullgraph=True)
def f(x: Tensor) -> Tensor:
    """Call ``numpy_sin`` from a function compiled with ``fullgraph=True``."""
    return numpy_sin(x)


result = f(x)
torch.testing.assert_close(result, x.sin())

######################################################################
# A PIL image transform, Python binding to a C++ extension, or another
# third-party library call follows the same pattern. If it returns tensors,
# write the fake kernel to match the real output metadata exactly: shape,
# strides, dtype, device, layout, and storage offset when relevant.

######################################################################
# Example: fake kernels must match strides
# ----------------------------------------
# The fake kernel must match the real output strides, not only the shape.
This112# operator returns a fresh Tensor with the same shape as ``x`` but different113# strides.114115116def numpy_sin_strided_impl(x: Tensor) -> Tensor:117result = torch.empty_strided(118x.shape,119tuple(reversed(x.stride())),120dtype=x.dtype,121device=x.device,122)123np.sin(x.detach().numpy(), out=result.numpy())124return result125126127@torch.library.custom_op(128"mylib_functional::numpy_sin_strided_bad",129mutates_args=(),130device_types="cpu",131)132def numpy_sin_strided_bad(x: Tensor) -> Tensor:133return numpy_sin_strided_impl(x)134135136@numpy_sin_strided_bad.register_fake137def _(x):138return torch.empty_like(x)139140141try:142torch.library.opcheck(numpy_sin_strided_bad, (torch.randn(2, 3),))143except Exception as exc:144print(f"opcheck caught incorrect fake kernel metadata: {type(exc).__name__}")145else:146torch_version = tuple(147int(part) for part in torch.__version__.split("+")[0].split(".")[:2]148)149if torch_version >= (2, 13):150raise AssertionError("Expected opcheck to fail")151print("PyTorch versions before 2.13 may not catch this metadata mismatch")152153154@torch.library.custom_op(155"mylib_functional::numpy_sin_strided",156mutates_args=(),157device_types="cpu",158)159def numpy_sin_strided(x: Tensor) -> Tensor:160return numpy_sin_strided_impl(x)161162163@numpy_sin_strided.register_fake164def _(x):165return torch.empty_strided(166x.shape,167tuple(reversed(x.stride())),168dtype=x.dtype,169device=x.device,170)171172173torch.library.opcheck(numpy_sin_strided, (torch.randn(2, 3),))174175######################################################################176# Testing Python custom operators177# -------------------------------178# Use ``torch.library.opcheck`` to test that the custom operator was registered179# correctly. This does not test numerical correctness; write separate tests for180# that.181#182# To use ``opcheck``, pass it a set of example inputs to test against. 
# If your operator supports training, then the examples should include
# Tensors that require grad. If your operator supports multiple devices, then
# the examples should include Tensors from each device.


examples = [
    (torch.randn(5),),  # 1-D input
    (torch.randn(0, 3),),  # input with zero elements
    (torch.randn(2, 3, dtype=torch.double),),  # float64 input
    (torch.randn(2, 3).t(),),  # transposed input
    (torch.randn(8)[1:],),  # sliced input that starts past the first element
]

for example in examples:
    torch.library.opcheck(numpy_sin, example)

######################################################################
# To add autograd, ``torch.vmap``, or other subsystem support, continue to
# :ref:`python-custom-ops-registrations`.