# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
**Extending the ONNX Registry**

Extending the ONNX Registry
===========================

**Authors:** Ti-Tai Wang ([email protected])
"""


###############################################################################
# Overview
# --------
#
# This tutorial is an introduction to the ONNX registry, which empowers users to implement new ONNX operators
# or even replace existing operators with a new implementation.
#
# During the model export to ONNX, the PyTorch model is lowered to an intermediate
# representation composed of `ATen operators <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
# While ATen operators are maintained by the PyTorch core team, it is the responsibility of the ONNX exporter team
# to independently implement each of these operators in ONNX through `ONNX Script <https://onnxscript.ai/>`_.
# Users can also replace the behavior implemented by the ONNX exporter team with their own implementation
# to fix bugs or improve performance for a specific ONNX runtime.
#
# The ONNX Registry manages the mapping between PyTorch operators and their ONNX counterparts, and provides
# APIs to extend the registry.
#
# In this tutorial, we will cover three scenarios that require extending the ONNX registry with custom operators:
#
# * Unsupported ATen operators
# * Custom operators with existing ONNX Runtime support
# * Custom operators without ONNX Runtime support
#
# Unsupported ATen operators
# --------------------------
#
# Although the ONNX exporter team does its best effort to support all ATen operators, some of them
# might not be supported yet. In this section, we will demonstrate how you can add
# unsupported ATen operators to the ONNX Registry.
#
# .. note::
#    The steps to implement an unsupported ATen operator are the same as those to replace the implementation of an
#    existing ATen operator with a custom one.
#    Because we don't actually have an unsupported ATen operator to use in this tutorial, we are going to leverage
#    this and replace the implementation of ``aten::add.Tensor`` with a custom implementation the same way we would
#    if the operator was not present in the ONNX Registry.
#
# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message
# similar to:
#
# .. code-block:: python
#
#    RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}.
#
# The error message indicates that the fully qualified name of the unsupported ATen operator is ``aten::add.Tensor``.
# The fully qualified name of an operator is composed of the namespace, operator name, and overload, following
# the format ``namespace::operator_name.overload``.
#
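# Given the fully qualified name, you can also query a registry directly instead of waiting for an export error.
# Here is a minimal sketch using the same :meth:`is_registered_op` API that we exercise later in this tutorial:
#
# .. code-block:: python
#
#    registry = torch.onnx.OnnxRegistry()
#    # aten::add.Tensor -> namespace="aten", op_name="add", overload="Tensor"
#    print(registry.is_registered_op(namespace="aten", op_name="add", overload="Tensor"))
#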
# To add support for an unsupported ATen operator or to replace the implementation for an existing one, we need:
#
# * The fully qualified name of the ATen operator (for example, ``aten::add.Tensor``).
#   This information is always present in the error message as shown above.
# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
#   ONNX Script is a prerequisite for this tutorial. Please make sure you have read the
#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
#   before proceeding.
#
# Because ``aten::add.Tensor`` is already supported by the ONNX Registry, we will demonstrate how to replace it with a
# custom implementation, but keep in mind that the same steps apply to support new unsupported ATen operators.
#
# This is possible because the :class:`OnnxRegistry` allows users to override an operator registration.
# We will override the registration of ``aten::add.Tensor`` with our custom implementation and verify it exists.
#

import torch
import onnxruntime
import onnxscript
from onnxscript import opset18  # opset 18 is the latest (and only) supported version for now

class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add(input_x, input_y)  # generates an aten::add.Tensor node

input_add_x = torch.randn(3, 4)
input_add_y = torch.randn(3, 4)
aten_add_model = Model()


# Now we create an ONNX Script function that implements ``aten::add.Tensor``.
# The function name (e.g. ``custom_aten_add``) is displayed in the ONNX graph, so we recommend using intuitive names.
custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
    alpha = opset18.CastLike(alpha, input_y)  # cast the alpha attribute to the dtype of input_y
    input_y = opset18.Mul(input_y, alpha)
    return opset18.Add(input_x, input_y)


# Now we have everything we need to support unsupported ATen operators.
# Let's register the ``custom_aten_add`` function to the ONNX registry, and export the model to ONNX again.
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add
)
print(f"aten::add.Tensor is supported by ONNX registry: \
{onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}"
)
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    aten_add_model, input_add_x, input_add_y, export_options=export_options
)

######################################################################
# Now let's inspect the model and verify the model has a ``custom_aten_add`` instead of ``aten::add.Tensor``.
# The graph has one graph node for ``custom_aten_add``, and inside of it there are four function nodes, one for each
# operator, and one for the constant attribute.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "custom.aten"
assert len(onnx_program.model_proto.graph.node) == 1
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""}
# function node name is the standard ONNX operator name
assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant", "CastLike"}
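
######################################################################
# Besides asserting on individual fields, you can dump the node structure as text straight from the same protobuf
# fields, as a plain-text alternative to the Netron visualization shown below. This is a convenience sketch of our
# own, not an official exporter API:

# walk the top-level graph nodes and the nodes inside the first local function
for node in onnx_program.model_proto.graph.node:
    print(f"graph node: {node.domain}::{node.op_type}")
for node in onnx_program.model_proto.functions[0].node:
    print(f"function node: {node.domain or '<onnx default>'}::{node.op_type}")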

######################################################################
# This is how ``custom_aten_add_model`` looks in the ONNX graph using Netron:
#
# .. image:: /_static/img/onnx/custom_aten_add_model.png
#    :width: 70%
#    :align: center
#
# Inside the ``custom_aten_add`` function, we can see the three ONNX nodes we
# used in the function (``CastLike``, ``Add``, and ``Mul``), and one ``Constant`` attribute:
#
# .. image:: /_static/img/onnx/custom_aten_add_function.png
#    :width: 70%
#    :align: center
#
# This was all that we needed to register the new ATen operator into the ONNX Registry.
# As an additional step, we can use ONNX Runtime to run the model, and compare the results with PyTorch.
#

# Use ONNX Runtime to run the model, and compare the results with PyTorch
onnx_program.save("./custom_add_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_add_model.onnx", providers=['CPUExecutionProvider']
)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_add_model(input_add_x, input_add_y)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))


######################################################################
# Custom operators with existing ONNX Runtime support
# ---------------------------------------------------
#
# In this case, the user creates a model with standard PyTorch operators, but the ONNX runtime
# (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation in the ONNX Registry. Another use case is when the user wants to use a custom implementation
# of an existing ONNX operator to fix a bug or improve the performance of a specific operator.
# To achieve this, we only need to register the new implementation with the existing ATen fully qualified name.
#
# In the following example, we use ``com.microsoft.Gelu`` from ONNX Runtime,
# which is not the same as the ``Gelu`` from the ONNX spec.
# Thus, we register ``Gelu`` with
# the namespace ``com.microsoft`` and operator name ``Gelu``.
#
# Before we begin, let's check whether ``aten::gelu.default`` is really supported by the ONNX registry.

onnx_registry = torch.onnx.OnnxRegistry()
print(f"aten::gelu.default is supported by ONNX registry: \
{onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}")


######################################################################
# In our example, the ``aten::gelu.default`` operator is supported by the ONNX registry,
# so :meth:`onnx_registry.is_registered_op` returns ``True``.

class CustomGelu(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)

# com.microsoft is an official ONNX Runtime namespace
custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_ort)
def custom_aten_gelu(input_x, approximate: str = "none"):
    # We know com.microsoft::Gelu is supported by ONNX Runtime
    # It's only not supported by ONNX
    return custom_ort.Gelu(input_x)


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu)
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)

aten_gelu_model = CustomGelu()
input_gelu_x = torch.randn(3, 3)

onnx_program = torch.onnx.dynamo_export(
    aten_gelu_model, input_gelu_x, export_options=export_options
)


######################################################################
# Let's inspect the model and verify the model uses op_type ``Gelu``
# from namespace ``com.microsoft``.
#
# .. note::
#    :func:`custom_aten_gelu` does not exist in the graph because
#    functions with fewer than three operators are inlined automatically.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"
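
######################################################################
# Because the graph now references the ``com.microsoft`` domain, that domain should also show up in the model's
# opset imports. The check below is a sketch of our own over the protobuf fields, not an official exporter API:

# every custom domain used by a node must be declared in opset_import
imported_domains = {opset.domain for opset in onnx_program.model_proto.opset_import}
print(f"imported domains: {imported_domains}")
assert "com.microsoft" in imported_domains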

######################################################################
# The following diagram shows the ``custom_aten_gelu_model`` ONNX graph using Netron;
# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
#
# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
# and compare the results with PyTorch.
#

onnx_program.save("./custom_gelu_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_gelu_model(input_gelu_x)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

######################################################################
# Custom operators without ONNX Runtime support
# ---------------------------------------------
#
# In this case, the operator is not supported by any ONNX runtime, but we
# would like to use it as a custom operator in the ONNX graph. Therefore, we need to implement
# the operator in three places:
#
# 1. PyTorch FX graph
# 2. ONNX Registry
# 3. ONNX Runtime
#
# In the following example, we would like to use a custom operator
# that takes one tensor input, and returns one output. The operator adds
# the input to itself, and returns the rounded result.
#
#
# Custom Ops Registration in PyTorch FX Graph (Beta)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Firstly, we need to implement the operator in the PyTorch FX graph.
# This can be done by using ``torch._custom_op``.
#

# NOTE: This is a beta feature in PyTorch, and is subject to change.
from torch._custom_op import impl as custom_op

@custom_op.custom_op("mylibrary::addandround_op")
def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
    ...

@addandround_op.impl_abstract()
def addandround_op_impl_abstract(tensor_x):
    return torch.empty_like(tensor_x)

@addandround_op.impl("cpu")
def addandround_op_impl(tensor_x):
    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result

torch._dynamo.allow_in_graph(addandround_op)

class CustomFoo(torch.nn.Module):
    def forward(self, tensor_x):
        return addandround_op(tensor_x)

input_addandround_x = torch.randn(3)
custom_addandround_model = CustomFoo()
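
######################################################################
# Before wiring the op into the exporter, it is worth confirming that the eager CPU implementation behaves as
# intended. This quick check is our own addition to the tutorial, not part of the registration steps:

# the custom op should match the plain PyTorch expression it wraps
expected_addandround = torch.round(input_addandround_x + input_addandround_x)
torch.testing.assert_close(addandround_op(input_addandround_x), expected_addandround)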

######################################################################
#
# Custom Ops Registration in ONNX Registry
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# For steps 2 and 3, we need to implement the operator in the ONNX registry.
# In this example, we will implement the operator in the ONNX registry
# with the namespace ``test.customop`` and operator names ``CustomOpOne``
# and ``CustomOpTwo``. These two ops are registered and built in
# `cpu_ops.cc <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc>`__.
#

custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_opset)
def custom_addandround(input_x):
    # The same as opset18.Add(x, x)
    add_x = custom_opset.CustomOpOne(input_x, input_x)
    # The same as opset18.Round(x)
    round_x = custom_opset.CustomOpTwo(add_x)
    # Cast to FLOAT to match the ONNX type
    return opset18.Cast(round_x, to=1)


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
)

export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    custom_addandround_model, input_addandround_x, export_options=export_options
)
onnx_program.save("./custom_addandround_model.onnx")


######################################################################
# The ``onnx_program`` exposes the exported model as protobuf through ``onnx_program.model_proto``.
# The graph has one graph node for ``custom_addandround``, and inside ``custom_addandround``,
# there are two function nodes, one for each operator.
#

assert onnx_program.model_proto.graph.node[0].domain == "test.customop"
assert onnx_program.model_proto.graph.node[0].op_type == "custom_addandround"
assert onnx_program.model_proto.functions[0].node[0].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[0].op_type == "CustomOpOne"
assert onnx_program.model_proto.functions[0].node[1].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[1].op_type == "CustomOpTwo"


######################################################################
# This is how the ``custom_addandround_model`` ONNX graph looks using Netron:
#
# .. image:: /_static/img/onnx/custom_addandround_model.png
#    :width: 70%
#    :align: center
#
# Inside the ``custom_addandround`` function, we can see the two custom operators we
# used in the function (``CustomOpOne`` and ``CustomOpTwo``), and they are from module
# ``test.customop``:
#
# .. image:: /_static/img/onnx/custom_addandround_function.png
#
# Custom Ops Registration in ONNX Runtime
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To link your custom op library to ONNX Runtime, you need to
# compile your C++ code into a shared library and link it to ONNX Runtime.
# Follow the instructions below:
#
# 1. Implement your custom op in C++ by following
#    `ONNX Runtime instructions <https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__.
# 2. Download the ONNX Runtime source distribution from
#    `ONNX Runtime releases <https://github.com/microsoft/onnxruntime/releases>`__.
# 3. Compile and link your custom op library to ONNX Runtime, for example:
#
#    .. code-block:: bash
#
#       $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC
#
# 4. Run the model with ONNX Runtime Python API and compare the results with PyTorch.
#
# .. code-block:: python
#
#    ort_session_options = onnxruntime.SessionOptions()
#
#    # NOTE: Link the custom op library to ONNX Runtime and replace the path
#    # with the path to your custom op library
#    ort_session_options.register_custom_ops_library(
#        "/path/to/libcustom_op_library.so"
#    )
#    ort_session = onnxruntime.InferenceSession(
#        "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)
#
#    def to_numpy(tensor):
#        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
#
#    onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x)
#    onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
#    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
#
#    torch_outputs = custom_addandround_model(input_addandround_x)
#    torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)
#
#    assert len(torch_outputs) == len(onnxruntime_outputs)
#    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
#        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
#
# Conclusion
# ----------
#
# Congratulations! In this tutorial, we explored the :class:`OnnxRegistry` API and
# discovered how to create custom implementations for unsupported or existing ATen operators
# using ONNX Script.
# Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
# providing us with a comprehensive understanding of handling unsupported
# operators in the ONNX ecosystem.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:
#