Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place.
Path: blob/main/advanced_source/python_custom_ops.py
Views: 712
# -*- coding: utf-8 -*-12"""3.. _python-custom-ops-tutorial:45Python Custom Operators6=======================78.. grid:: 2910.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn11:class-card: card-prerequisites1213* How to integrate custom operators written in Python with PyTorch14* How to test custom operators using ``torch.library.opcheck``1516.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites17:class-card: card-prerequisites1819* PyTorch 2.4 or later2021PyTorch offers a large library of operators that work on Tensors (e.g.22``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized23operator with PyTorch, perhaps written by a third-party library. This tutorial24shows how to wrap Python functions so that they behave like PyTorch native25operators. Reasons why you may wish to create a custom operator in PyTorch include:2627- Treating an arbitrary Python function as an opaque callable with respect28to ``torch.compile`` (that is, prevent ``torch.compile`` from tracing29into the function).30- Adding training support to an arbitrary Python function3132Please note that if your operation can be expressed as a composition of33existing PyTorch operators, then there is usually no need to use the custom operator34API -- everything (for example ``torch.compile``, training support) should35just work.36"""37######################################################################38# Example: Wrapping PIL's crop into a custom operator39# ------------------------------------40# Let's say that we are using PIL's ``crop`` operation.4142import torch43from torchvision.transforms.functional import to_pil_image, pil_to_tensor44import PIL45import IPython46import matplotlib.pyplot as plt4748def crop(pic, box):49img = to_pil_image(pic.cpu())50cropped_img = img.crop(box)51return pil_to_tensor(cropped_img).to(pic.device) / 255.5253def display(img):54plt.imshow(img.numpy().transpose((1, 2, 0)))5556img = torch.ones(3, 64, 64)57img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)58display(img)5960######################################################################6162cropped_img = crop(img, (10, 10, 50, 50))63display(cropped_img)6465######################################################################66# ``crop`` is not handled effectively out-of-the-box by67# ``torch.compile``: ``torch.compile`` induces a68# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_69# on functions it is unable to handle and graph breaks are bad for performance.70# The following code demonstrates this by raising an error71# (``torch.compile`` with ``fullgraph=True`` raises an error if a72# graph break occurs).7374@torch.compile(fullgraph=True)75def f(img):76return crop(img, (10, 10, 50, 50))7778# The following raises an error. Uncomment the line to see it.79# cropped_img = f(img)8081######################################################################82# In order to black-box ``crop`` for use with ``torch.compile``, we need to83# do two things:84#85# 1. wrap the function into a PyTorch custom operator.86# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.87# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage),88# this function should return dummy Tensors of your choice with the correct89# Tensor metadata (shape/strides/``dtype``/device).909192from typing import Sequence9394# Use torch.library.custom_op to define a new custom operator.95# If your operator mutates any input Tensors, their names must be specified96# in the ``mutates_args`` argument.97@torch.library.custom_op("mylib::crop", mutates_args=())98def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:99img = to_pil_image(pic.cpu())100cropped_img = img.crop(box)101return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)102103# Use register_fake to add a ``FakeTensor`` kernel for the operator104@crop.register_fake105def _(pic, box):106channels = pic.shape[0]107x0, y0, x1, y1 = box108return pic.new_empty(channels, y1 - y0, x1 - x0)109110######################################################################111# After this, ``crop`` now works without graph breaks:112113@torch.compile(fullgraph=True)114def f(img):115return crop(img, (10, 10, 50, 50))116117cropped_img = f(img)118display(img)119120######################################################################121122display(cropped_img)123124######################################################################125# Adding training support for crop126# --------------------------------127# Use ``torch.library.register_autograd`` to add training support for an operator.128# Prefer this over directly using ``torch.autograd.Function``; some compositions of129# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and130# has led to) silent incorrectness when composed with ``torch.compile``.131#132# If you don't need training support, there is no need to use133# ``torch.library.register_autograd``.134# If you end up training with a ``custom_op`` that doesn't have an autograd135# registration, we'll raise an error message.136#137# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the138# derivation as an exercise to the reader). Let's first wrap ``paste`` into a139# custom operator:140141@torch.library.custom_op("mylib::paste", mutates_args=())142def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:143assert im1.device == im2.device144assert im1.dtype == im2.dtype145im1_pil = to_pil_image(im1.cpu())146im2_pil = to_pil_image(im2.cpu())147PIL.Image.Image.paste(im1_pil, im2_pil, coord)148return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)149150@paste.register_fake151def _(im1, im2, coord):152assert im1.device == im2.device153assert im1.dtype == im2.dtype154return torch.empty_like(im1)155156######################################################################157# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:158159def backward(ctx, grad_output):160grad_input = grad_output.new_zeros(ctx.pic_shape)161grad_input = paste(grad_input, grad_output, ctx.coords)162return grad_input, None163164def setup_context(ctx, inputs, output):165pic, box = inputs166ctx.coords = box[:2]167ctx.pic_shape = pic.shape168169crop.register_autograd(backward, setup_context=setup_context)170171######################################################################172# Note that the backward must be a composition of PyTorch-understood operators,173# which is why we wrapped paste into a custom operator instead of directly using174# PIL's paste.175176img = img.requires_grad_()177result = crop(img, (10, 10, 50, 50))178result.sum().backward()179display(img.grad)180181######################################################################182# This is the correct gradient, with 1s (white) in the cropped region and 0s183# (black) in the unused region.184185######################################################################186# Testing Python Custom operators187# -------------------------------188# Use ``torch.library.opcheck`` to test that the custom operator was registered189# correctly. This does not test that the gradients are mathematically correct;190# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).191#192# To use ``opcheck``, pass it a set of example inputs to test against. If your193# operator supports training, then the examples should include Tensors that194# require grad. If your operator supports multiple devices, then the examples195# should include Tensors from each device.196197examples = [198[torch.randn(3, 64, 64), [0, 0, 10, 10]],199[torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],200[torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],201[torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],202]203204for example in examples:205torch.library.opcheck(crop, example)206207######################################################################208# Mutable Python Custom operators209# -------------------------------210# You can also wrap a Python function that mutates its inputs into a custom211# operator.212# Functions that mutate inputs are common because that is how many low-level213# kernels are written; for example, a kernel that computes ``sin`` may take in214# the input and an output tensor and write ``input.sin()`` to the output tensor.215#216# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python217# custom operator.218219import numpy as np220221@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")222def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:223assert input.device == output.device224assert input.device.type == "cpu"225input_np = input.numpy()226output_np = output.numpy()227np.sin(input_np, out=output_np)228229######################################################################230# Because the operator doesn't return anything, there is no need to register231# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.232233@torch.compile(fullgraph=True)234def f(x):235out = torch.empty(3)236numpy_sin(x, out)237return out238239x = torch.randn(3)240y = f(x)241assert torch.allclose(y, x.sin())242243######################################################################244# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.245# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.246247example_inputs = [248[torch.randn(3), torch.empty(3)],249[torch.randn(0, 3), torch.empty(0, 3)],250[torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],251]252253for example in example_inputs:254torch.library.opcheck(numpy_sin, example)255256######################################################################257# Conclusion258# ----------259# In this tutorial, we learned how to use ``torch.library.custom_op`` to260# create a custom operator in Python that works with PyTorch subsystems261# such as ``torch.compile`` and autograd.262#263# This tutorial provides a basic introduction to custom operators.264# For more detailed information, see:265#266# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_267# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_268#269270271