Path: blob/main/advanced_source/python_custom_ops_mutable.py
6463 views
# -*- coding: utf-8 -*-

"""
.. _python-custom-ops-mutable:

Mutable Python Custom Operators
===============================

:ref:`Functional custom operators <python-custom-ops-functional>` showed
``numpy.sin`` as an operator that returns a fresh Tensor. This page shows the
mutable version: a kernel that writes ``sin(x)`` into an existing output Tensor.
Mutable operators have a different contract from functional operators.

Before writing the operator, read the required schema and mutation/aliasing
contract rules in :ref:`python-custom-ops-schema-contract`.

Checklist:

* choose one mutation pattern and keep it stable;
* list every mutated Tensor argument in ``mutates_args``;
* do not return mutated inputs unless you are using a tagged in-place or
  ``out=`` operator (available starting in PyTorch 2.13);
* validate the operator with ``torch.library.opcheck``.
"""

######################################################################
# Choose one mutation contract
# ----------------------------
# Choose the mutation behavior before adding optional registrations.
PyTorch29# needs this contract for functionalization in ``torch.compile``30# and autograd.31#32# If the operator does not mutate any Tensor input, use the functional operator33# path instead.34#35# If the operator mutates the first positional Tensor and returns it, use a36# tagged in-place operator, starting in PyTorch 2.13.37#38# If the operator accepts write-only keyword-only ``out=`` Tensor arguments and39# returns them, use a tagged ``out=`` operator, starting in PyTorch 2.13.40#41# For other mutable operators, list every mutated argument in ``mutates_args``42# and do not return mutated inputs or their aliases.4344import numpy as np45import torch46from torch import Tensor474849######################################################################50# Example: write NumPy sin into an output buffer51# ----------------------------------------------52# Functions that mutate inputs are common because that is how many low-level53# kernels are written; for example, a kernel that computes ``sin`` may take in54# the input and an output tensor and write ``input.sin()`` to the output tensor.55#56# This operator writes ``sin(x)`` into ``out`` and returns ``None``.575859@torch.library.custom_op(60"mylib_mutable::numpy_sin_out",61mutates_args={"out"},62device_types="cpu",63)64def numpy_sin_out(x: Tensor, out: Tensor) -> None:65if x.shape != out.shape:66raise RuntimeError("x and out must have the same shape")67if x.dtype != out.dtype:68raise RuntimeError("x and out must have the same dtype")69if x.device != out.device:70raise RuntimeError("x and out must be on the same device")71np.sin(x.detach().numpy(), out=out.numpy())727374x = torch.randn(5)75out = torch.empty_like(x)76numpy_sin_out(x, out)77torch.testing.assert_close(out, x.sin())7879######################################################################80# Because the operator doesn't return anything, there is no need to register a81# ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.82# If a 
# If a mutable operator also returns a fresh Tensor, register a fake kernel for
# that output.


@torch.compile(fullgraph=True)
def compiled_numpy_sin_out(x):
    # Allocate the output buffer inside the compiled region and let the
    # mutable custom operator fill it.
    out = torch.empty_like(x)
    numpy_sin_out(x, out)
    return out


torch.testing.assert_close(compiled_numpy_sin_out(x), x.sin())

######################################################################
# PyTorch-style in-place and out= operators
# -----------------------------------------
# Starting in PyTorch 2.13, ``torch.library.custom_op`` supports tagged
# in-place and ``out=`` custom operators.
# Tagged in-place operators return the same Tensor they mutate. Tagged ``out=``
# operators return their keyword-only output buffers in declaration order.
# This example uses ``mylib_mutable::sin_`` for a tagged in-place custom
# operator and ``mylib_mutable::sin_out`` for a tagged ``out=`` custom operator.
#
# The checks below confirm both halves of that contract: the returned Tensor
# holds the expected values, and it shares memory with the buffer that was
# written to.


# Feature-detect the tags so that PyTorch builds without them skip this
# section instead of failing.
supports_tagged_mutable_ops = (
    hasattr(torch, "Tag")
    and hasattr(torch.Tag, "inplace")
    and hasattr(torch.Tag, "out")
)

if supports_tagged_mutable_ops:

    @torch.library.custom_op(
        "mylib_mutable::sin_",
        mutates_args={"x"},
        tags=torch.Tag.inplace,
    )
    def sin_(x: Tensor) -> Tensor:
        x.sin_()
        return x

    @torch.library.custom_op(
        "mylib_mutable::sin_out",
        mutates_args={"out"},
        tags=torch.Tag.out,
    )
    def sin_out(x: Tensor, *, out: Tensor) -> Tensor:
        torch.sin(x, out=out)
        return out

    x_for_inplace = torch.randn(3)
    expected = x_for_inplace.sin()
    inplace_result = sin_(x_for_inplace)
    torch.testing.assert_close(inplace_result, expected)
    # Checking only the return value would also pass for an operator that
    # returned a fresh Tensor, so also verify that the input was overwritten
    # and that the returned Tensor shares its memory.
    torch.testing.assert_close(x_for_inplace, expected)
    assert inplace_result.data_ptr() == x_for_inplace.data_ptr()

    out_for_sin = torch.empty_like(x)
    out_result = sin_out(x, out=out_for_sin)
    torch.testing.assert_close(out_result, x.sin())
    torch.testing.assert_close(out_for_sin, x.sin())
    assert out_result.data_ptr() == out_for_sin.data_ptr()

    torch.library.opcheck(sin_, (torch.randn(3),))
    torch.library.opcheck(
        sin_out,
        (torch.randn(3),),
        {"out": torch.empty(3)},
    )
else:
    print("Tagged in-place and out= custom operators require PyTorch 2.13 or later.")


######################################################################
# Validate the operator
# ---------------------
# ``torch.library.opcheck`` tests the registered operator and raises an error
# if the registration is incorrect. For example, it would fail if ``out`` were
# missing from ``mutates_args``. The examples below cover a 1-D Tensor, an
# empty Tensor, a ``double`` Tensor, and non-contiguous input and output
# Tensors.


examples = [
    (torch.randn(5), torch.empty(5)),
    (torch.randn(0, 3), torch.empty(0, 3)),
    (
        torch.randn(2, 3, dtype=torch.double),
        torch.empty(2, 3, dtype=torch.double),
    ),
    (
        torch.randn(2, 3).t(),
        torch.empty_strided((3, 2), (1, 3)),
    ),
]

for example in examples:
    torch.library.opcheck(numpy_sin_out, example)

######################################################################
# For autograd, ``torch.vmap``, or other subsystem behavior, continue to
# :ref:`python-custom-ops-registrations`.