GitHub Repository: pytorch/tutorials
Path: blob/main/beginner_source/onnx/onnx_registry_tutorial.py
# -*- coding: utf-8 -*-

"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
**Extending the ONNX Registry**

Extending the ONNX Registry
===========================

**Authors:** Ti-Tai Wang
"""


###############################################################################
# Overview
# --------
#
# This tutorial is an introduction to the ONNX registry, which empowers users to implement new ONNX operators
# or even replace existing operators with a new implementation.
#
# During the model export to ONNX, the PyTorch model is lowered to an intermediate
# representation composed of `ATen operators <https://pytorch.org/docs/stable/torch.compiler_ir.html>`_.
# While ATen operators are maintained by the PyTorch core team, it is the responsibility of the ONNX exporter team
# to independently implement each of these operators in ONNX through `ONNX Script <https://onnxscript.ai/>`_.
# Users can also replace the behavior implemented by the ONNX exporter team with their own implementation
# to fix bugs or improve performance for a specific ONNX runtime.
#
# The ONNX Registry manages the mapping between PyTorch operators and their ONNX counterparts and provides
# APIs to extend the registry.
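#
# To make the idea concrete, the cell below is a minimal, hypothetical sketch of such a mapping in plain
# Python. ``ToyOnnxRegistry`` is not :class:`torch.onnx.OnnxRegistry` and does not reflect its internals.
# It only reuses the ``register_op`` / ``is_registered_op`` method names and the ``namespace``, ``op_name``,
# ``overload`` and ``function`` keywords that appear later in this tutorial. It assumes that registering
# the same key twice overrides the earlier entry.
#

from typing import Callable, Dict, Tuple


class ToyOnnxRegistry:
    """Hypothetical stand-in that maps an ATen-style operator to an implementation."""

    def __init__(self) -> None:
        # Keyed by (namespace, op_name, overload), e.g. ("aten", "add", "Tensor").
        self._ops: Dict[Tuple[str, str, str], Callable] = {}

    def register_op(self, function: Callable, namespace: str, op_name: str, overload: str) -> None:
        # Registering the same key again replaces (overrides) the previous entry.
        self._ops[(namespace, op_name, overload)] = function

    def is_registered_op(self, namespace: str, op_name: str, overload: str) -> bool:
        return (namespace, op_name, overload) in self._ops

    def lookup(self, namespace: str, op_name: str, overload: str) -> Callable:
        return self._ops[(namespace, op_name, overload)]


toy_registry = ToyOnnxRegistry()
toy_registry.register_op(lambda x, y: "exporter add", namespace="aten", op_name="add", overload="Tensor")
# A second registration for the same key overrides the first one.
toy_registry.register_op(lambda x, y: "custom add", namespace="aten", op_name="add", overload="Tensor")
print(toy_registry.lookup("aten", "add", "Tensor")(1, 2))  # custom add
print(toy_registry.is_registered_op("aten", "gelu", "default"))  # False

######################################################################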
#
# In this tutorial, we will cover three scenarios that require extending the ONNX registry with custom operators:
#
# * Unsupported ATen operators
# * Custom operators with existing ONNX Runtime support
# * Custom operators without ONNX Runtime support
#
# Unsupported ATen operators
# --------------------------
#
# Although the ONNX exporter team makes its best effort to support all ATen operators, some of them
# might not be supported yet. In this section, we will demonstrate how you can add
# unsupported ATen operators to the ONNX Registry.
#
# .. note::
#    The steps to implement unsupported ATen operators are the same as those to replace the implementation of an existing
#    ATen operator with a custom implementation.
#    Because we don't actually have an unsupported ATen operator to use in this tutorial, we are going to leverage
#    this and replace the implementation of ``aten::add.Tensor`` with a custom implementation the same way we would
#    if the operator were not present in the ONNX Registry.
#
# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message
# similar to:
#
# .. code-block:: python
#
#    RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}.
#
# The error message indicates that the fully qualified name of the unsupported ATen operator is ``aten::add.Tensor``.
# The fully qualified name of an operator is composed of the namespace, operator name, and overload following
# the format ``namespace::operator_name.overload``.
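#
# As an illustration, the hypothetical helper below (not part of any PyTorch API) splits a fully
# qualified name into the three pieces that ``register_op`` asks for, using plain string operations.
# It assumes that a name without a ``.overload`` suffix refers to the ``default`` overload, which
# matches the ``aten::gelu.default`` spelling used later in this tutorial.
#

from typing import NamedTuple


class OpName(NamedTuple):
    namespace: str
    op_name: str
    overload: str


def parse_qualified_name(qualified_name: str) -> OpName:
    # Split "aten::add.Tensor" into "aten" and "add.Tensor".
    namespace, sep, rest = qualified_name.partition("::")
    if not sep or not namespace or not rest:
        raise ValueError(f"expected 'namespace::op_name[.overload]', got {qualified_name!r}")
    # Split "add.Tensor" into "add" and "Tensor" at the first dot.
    op_name, dot, overload = rest.partition(".")
    # Assumption: a missing overload means the "default" overload.
    return OpName(namespace, op_name, overload if dot else "default")


print(parse_qualified_name("aten::add.Tensor"))
print(parse_qualified_name("mylibrary::addandround_op"))

######################################################################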
#
# To add support for an unsupported ATen operator or to replace the implementation for an existing one, we need:
#
# * The fully qualified name of the ATen operator (e.g. ``aten::add.Tensor``).
#   This information is always present in the error message as shown above.
# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__.
#   ONNX Script is a prerequisite for this tutorial. Please make sure you have read the
#   `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
#   before proceeding.
#
# Because ``aten::add.Tensor`` is already supported by the ONNX Registry, we will demonstrate how to replace it with a
# custom implementation, but keep in mind that the same steps apply to adding support for unsupported ATen operators.
#
# This is possible because the :class:`OnnxRegistry` allows users to override an operator registration.
# We will override the registration of ``aten::add.Tensor`` with our custom implementation and verify it exists.
#

import torch
import onnxruntime
import onnxscript
from onnxscript import opset18  # opset 18 is the latest (and only) supported version for now

class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add(input_x, input_y)  # generates an aten::add.Tensor node

input_add_x = torch.randn(3, 4)
input_add_y = torch.randn(3, 4)
aten_add_model = Model()


# Now we create an ONNX Script function that implements ``aten::add.Tensor``.
# The function name (e.g. ``custom_aten_add``) is displayed in the ONNX graph, so we recommend using intuitive names.
custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)

# NOTE: The function signature must match the signature of the unsupported ATen operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
    input_y = opset18.Mul(input_y, alpha)
    return opset18.Add(input_x, input_y)
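
# For reference, ``aten::add.Tensor`` computes ``input + alpha * other`` elementwise,
# which is why the ONNX Script function above multiplies ``input_y`` by ``alpha``
# before adding. The plain-Python sketch below is a hypothetical helper on flat lists,
# not part of the tutorial's API. It mirrors that Mul-then-Add order so the expected
# values are easy to check by hand. Broadcasting is deliberately not implemented.
from typing import List


def reference_add(xs: List[float], ys: List[float], alpha: float = 1.0) -> List[float]:
    if len(xs) != len(ys):
        raise ValueError("this sketch does not implement broadcasting")
    scaled = [y * alpha for y in ys]  # corresponds to opset18.Mul(input_y, alpha)
    return [x + s for x, s in zip(xs, scaled)]  # corresponds to opset18.Add(input_x, input_y)


print(reference_add([1.0, 2.0], [10.0, 20.0], alpha=0.5))  # [6.0, 12.0]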


# Now we have everything we need to support unsupported ATen operators.
# Let's register the ``custom_aten_add`` function with the ONNX registry, and export the model to ONNX again.
onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add
)
print(
    "aten::add.Tensor is supported by ONNX registry: "
    f"{onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}"
)
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    aten_add_model, input_add_x, input_add_y, export_options=export_options
)

######################################################################
# Now let's inspect the model and verify that it has a ``custom_aten_add`` node instead of ``aten::add.Tensor``.
# The graph has one graph node for ``custom_aten_add``, and inside of it there are three function nodes:
# one for each of the two operators (``Mul`` and ``Add``), and one ``Constant`` node for the ``alpha`` attribute.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "custom.aten"
assert len(onnx_program.model_proto.graph.node) == 1
# graph node op_type is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""}
# function node op_type is the standard ONNX operator name
assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"}


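######################################################################
# The assertions above only look at the ``domain`` and ``op_type`` of each node. The function below is an
# illustrative, hypothetical helper (not part of ``torch.onnx``). It tallies ``(domain, op_type)`` pairs for
# any sequence of objects with those two attributes, such as the entries of
# ``onnx_program.model_proto.graph.node`` or ``onnx_program.model_proto.functions[0].node``. It reports the
# empty domain as ``ai.onnx``, the name of the default ONNX operator set. To keep the sketch runnable on its
# own, it is exercised here with ``types.SimpleNamespace`` stand-ins rather than real protobuf nodes.
#

from collections import Counter
from types import SimpleNamespace
from typing import Iterable


def summarize_nodes(nodes: Iterable) -> Counter:
    # "" is the default ONNX domain, conventionally spelled "ai.onnx".
    return Counter((node.domain or "ai.onnx", node.op_type) for node in nodes)


fake_nodes = [
    SimpleNamespace(domain="", op_type="Mul"),
    SimpleNamespace(domain="", op_type="Add"),
    SimpleNamespace(domain="com.microsoft", op_type="Gelu"),
]
print(summarize_nodes(fake_nodes))
# With a real export, one could call, e.g., summarize_nodes(onnx_program.model_proto.functions[0].node)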
######################################################################
# This is how ``custom_aten_add_model`` looks in the ONNX graph using Netron:
#
# .. image:: /_static/img/onnx/custom_aten_add_model.png
#    :width: 70%
#    :align: center
#
# Inside the ``custom_aten_add`` function, we can see the standard ONNX operators we
# used in the function and the ``Constant`` node created for the ``alpha`` attribute:
#
# .. image:: /_static/img/onnx/custom_aten_add_function.png
#    :width: 70%
#    :align: center
#
# This was all that we needed to register the new ATen operator into the ONNX Registry.
# As an additional step, we can use ONNX Runtime to run the model, and compare the results with PyTorch.
#


# Use ONNX Runtime to run the model, and compare the results with PyTorch
onnx_program.save("./custom_add_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_add_model.onnx", providers=['CPUExecutionProvider']
)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_add_model(input_add_x, input_add_y)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))


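######################################################################
# ``torch.testing.assert_close`` treats two floating-point values as equal when
# ``|actual - expected| <= atol + rtol * |expected|``. Per the PyTorch documentation, the defaults for
# ``float32`` are ``rtol=1.3e-6`` and ``atol=1e-5``. The hypothetical helper below is a plain-Python sketch
# of that elementwise rule on flat lists. It may help explain why ONNX Runtime and PyTorch outputs can differ
# slightly and still pass the check. It is not a replacement for ``assert_close``, which also checks shapes
# and dtypes and handles NaNs.
#

from typing import Sequence


def within_tolerance(actual: Sequence[float], expected: Sequence[float],
                     rtol: float = 1.3e-6, atol: float = 1e-5) -> bool:
    if len(actual) != len(expected):
        return False
    # Same elementwise inequality as assert_close; note it is not symmetric in actual/expected.
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(within_tolerance([1.000001, 2.0], [1.0, 2.0]))  # True: a difference of 1e-6 is within ~1.1e-5
print(within_tolerance([1.001, 2.0], [1.0, 2.0]))     # False: a difference of 1e-3 is too large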
######################################################################
# Custom operators with existing ONNX Runtime support
# ---------------------------------------------------
#
# In this case, the user creates a model with standard PyTorch operators, but the ONNX runtime
# (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation in the ONNX Registry. Another use case is when the user wants to use a custom implementation
# of an existing ONNX operator to fix a bug or improve the performance of a specific operator.
# To achieve this, we only need to register the new implementation with the existing ATen fully qualified name.
#
# In the following example, we use the ``com.microsoft.Gelu`` operator from ONNX Runtime,
# which is not the same as the ``Gelu`` operator from the ONNX spec. Thus, our ONNX Script function
# uses the ``com.microsoft`` domain and the operator name ``Gelu``, and we register it as the
# implementation of ``aten::gelu.default``.
#
# Before we begin, let's check whether ``aten::gelu.default`` is really supported by the ONNX registry.

onnx_registry = torch.onnx.OnnxRegistry()
print(
    "aten::gelu.default is supported by ONNX registry: "
    f"{onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}"
)


######################################################################
# In our example, the ``aten::gelu.default`` operator is supported by the ONNX registry,
# so :meth:`onnx_registry.is_registered_op` returns ``True``.

class CustomGelu(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)

# com.microsoft is an official ONNX Runtime namespace
custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature must match the signature of the ATen operator being implemented.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_ort)
def custom_aten_gelu(input_x, approximate: str = "none"):
    # We know com.microsoft::Gelu is supported by ONNX Runtime;
    # it is just not part of the standard ONNX operator set
    return custom_ort.Gelu(input_x)
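
# Background sketch: ``aten::gelu`` takes an ``approximate`` argument (``"none"`` or ``"tanh"``),
# while ``com.microsoft::Gelu``, per the ONNX Runtime contrib-operator documentation, computes the
# exact erf-based formula. Since ``custom_aten_gelu`` above does not use ``approximate``, it appears
# to always export the exact form. The example model uses the default ``"none"``, so the results
# still match. The hypothetical helpers below use only the standard ``math`` module to show the two
# formulas side by side for scalar inputs.
import math


def gelu_exact(x: float) -> float:
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


for value in (-1.0, 0.0, 1.0):
    print(value, gelu_exact(value), gelu_tanh(value))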


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu
)
export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)

aten_gelu_model = CustomGelu()
input_gelu_x = torch.randn(3, 3)

onnx_program = torch.onnx.dynamo_export(
    aten_gelu_model, input_gelu_x, export_options=export_options
)


######################################################################
# Let's inspect the model and verify that it uses the op_type ``Gelu``
# from the namespace ``com.microsoft``.
#
# .. note::
#    :func:`custom_aten_gelu` does not exist in the graph because
#    functions with fewer than three operators are inlined automatically.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node op_type is Gelu because the function was inlined
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"


######################################################################
# The following diagram shows the ``custom_aten_gelu_model`` ONNX graph in Netron.
# We can see the ``Gelu`` node from the ``com.microsoft`` domain directly in the graph:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
#
# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
# and compare the results with PyTorch.
#

onnx_program.save("./custom_gelu_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "./custom_gelu_model.onnx", providers=['CPUExecutionProvider']
)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x)
onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)

torch_outputs = aten_gelu_model(input_gelu_x)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

######################################################################
# Custom operators without ONNX Runtime support
# ---------------------------------------------
#
# In this case, the operator is not supported by any ONNX runtime, but we
# would like to use it as a custom operator in the ONNX graph. Therefore, we need to implement
# the operator in three places:
#
# 1. PyTorch FX graph
# 2. ONNX Registry
# 3. ONNX Runtime
#
# In the following example, we would like to use a custom operator
# that takes one tensor input and returns one output. The operator adds
# the input to itself and returns the rounded result.
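#
# Because rounding is easy to get subtly wrong, it helps to pin down the expected values first.
# ``torch.round`` and the ONNX ``Round`` operator both round halfway cases to the nearest even
# integer, as does Python's built-in ``round``. A hypothetical plain-Python reference for the
# operator described above could therefore look like this:
#

from typing import List


def addandround_reference(values: List[float]) -> List[float]:
    # x + x, then round half to even (e.g. 0.5 -> 0.0, 1.5 -> 2.0)
    return [float(round(v + v)) for v in values]


print(addandround_reference([0.25, 0.75, 1.2]))  # [0.0, 2.0, 2.0]

######################################################################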
#
# Custom Ops Registration in PyTorch FX Graph (Beta)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# First, we need to implement the operator in the PyTorch FX graph.
# This can be done by using ``torch._custom_op``.
#

# NOTE: This is a beta feature in PyTorch, and is subject to change.
from torch._custom_op import impl as custom_op

@custom_op.custom_op("mylibrary::addandround_op")
def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:
    ...

@addandround_op.impl_abstract()
def addandround_op_impl_abstract(tensor_x):
    return torch.empty_like(tensor_x)

@addandround_op.impl("cpu")
def addandround_op_impl(tensor_x):
    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result

torch._dynamo.allow_in_graph(addandround_op)

class CustomFoo(torch.nn.Module):
    def forward(self, tensor_x):
        return addandround_op(tensor_x)

input_addandround_x = torch.randn(3)
custom_addandround_model = CustomFoo()


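######################################################################
# The decorators above attach two implementations to ``addandround_op``. The ``impl_abstract`` one
# only describes the output: ``torch.empty_like`` yields a tensor with the input's shape and dtype but
# no meaningful values. Roughly speaking, this lets tracing reason about the operator without running
# it, while ``impl("cpu")`` computes real values. The toy class below is a hypothetical, plain-Python
# illustration of that decorator-registration pattern only. It is not how ``torch._custom_op`` works
# internally, and it models a tensor as a ``(shape, values)`` tuple.
#

from typing import Callable, Dict, List, Optional, Tuple

# A "tensor" in this toy is (shape, values); values is None for metadata-only outputs.
ToyTensor = Tuple[Tuple[int, ...], Optional[List[float]]]


class ToyCustomOp:
    """Hypothetical stand-in for a custom operator with per-key implementations."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._kernels: Dict[str, Callable[[ToyTensor], ToyTensor]] = {}

    def impl(self, key: str) -> Callable:
        # Returns a decorator that stores the decorated function under `key`.
        def decorator(fn: Callable[[ToyTensor], ToyTensor]) -> Callable[[ToyTensor], ToyTensor]:
            self._kernels[key] = fn
            return fn
        return decorator

    def impl_abstract(self) -> Callable:
        # In this toy, the abstract implementation is simply stored under its own key.
        return self.impl("abstract")

    def __call__(self, x: ToyTensor, key: str = "cpu") -> ToyTensor:
        return self._kernels[key](x)


toy_addandround = ToyCustomOp("mylibrary::addandround_op")


@toy_addandround.impl_abstract()
def toy_addandround_abstract(x: ToyTensor) -> ToyTensor:
    shape, _ = x
    return (shape, None)  # same shape as the input, no values computed


@toy_addandround.impl("cpu")
def toy_addandround_cpu(x: ToyTensor) -> ToyTensor:
    shape, values = x
    assert values is not None, "the cpu kernel needs concrete values"
    return (shape, [float(round(v + v)) for v in values])


print(toy_addandround(((3,), None), key="abstract"))  # ((3,), None)
print(toy_addandround(((3,), [0.25, 0.75, 1.2])))     # ((3,), [0.0, 2.0, 2.0])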
######################################################################
#
# Custom Ops Registration in ONNX Registry
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# For step 2, we need to implement the operator in the ONNX registry; step 3 is covered in the next section.
# In this example, we will implement the operator in the ONNX registry
# with the namespace ``test.customop`` and the operator names ``CustomOpOne``
# and ``CustomOpTwo``. These two ops are registered and built in
# `cpu_ops.cc <https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc>`__.
#


custom_opset = onnxscript.values.Opset(domain="test.customop", version=1)

# NOTE: The function signature must match the signature of the operator being implemented;
# here that is the custom ``mylibrary::addandround_op`` rather than an ATen operator.
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_opset)
def custom_addandround(input_x):
    # The same as opset18.Add(x, x)
    add_x = custom_opset.CustomOpOne(input_x, input_x)
    # The same as opset18.Round(x)
    round_x = custom_opset.CustomOpTwo(add_x)
    # Cast to FLOAT to match the ONNX type
    return opset18.Cast(round_x, to=1)
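
# The ``to`` attribute of the ONNX ``Cast`` operator is an ``onnx.TensorProto.DataType``
# enum value, so ``to=1`` above means FLOAT (32-bit float). A few of the standard values
# are listed below as a plain dictionary. This is a hypothetical lookup table added for
# readability. With the ``onnx`` package installed, ``onnx.TensorProto.FLOAT`` evaluates
# to the same integer.
ONNX_DTYPE_CODES = {
    "FLOAT": 1,
    "UINT8": 2,
    "INT8": 3,
    "INT32": 6,
    "INT64": 7,
    "BOOL": 9,
    "FLOAT16": 10,
    "DOUBLE": 11,
    "BFLOAT16": 16,
}
# Reverse mapping: enum value -> name
ONNX_DTYPE_NAMES = {code: name for name, code in ONNX_DTYPE_CODES.items()}
print(ONNX_DTYPE_NAMES[1])  # FLOAT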


onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(
    namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround
)

export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
onnx_program = torch.onnx.dynamo_export(
    custom_addandround_model, input_addandround_x, export_options=export_options
)
onnx_program.save("./custom_addandround_model.onnx")


######################################################################
# The ``onnx_program`` exposes the exported model as a protobuf through ``onnx_program.model_proto``.
# The graph has one graph node for ``custom_addandround``, and inside ``custom_addandround``
# there are three function nodes: ``CustomOpOne``, ``CustomOpTwo``, and the ``Cast`` that converts
# the result back to FLOAT. The assertions below check the two custom operators.
#

assert onnx_program.model_proto.graph.node[0].domain == "test.customop"
assert onnx_program.model_proto.graph.node[0].op_type == "custom_addandround"
assert onnx_program.model_proto.functions[0].node[0].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[0].op_type == "CustomOpOne"
assert onnx_program.model_proto.functions[0].node[1].domain == "test.customop"
assert onnx_program.model_proto.functions[0].node[1].op_type == "CustomOpTwo"


######################################################################
# This is how the ``custom_addandround_model`` ONNX graph looks in Netron:
#
# .. image:: /_static/img/onnx/custom_addandround_model.png
#    :width: 70%
#    :align: center
#
# Inside the ``custom_addandround`` function, we can see the two custom operators we
# used in the function (``CustomOpOne`` and ``CustomOpTwo``), and that they are from the
# ``test.customop`` domain:
#
# .. image:: /_static/img/onnx/custom_addandround_function.png
#
# Custom Ops Registration in ONNX Runtime
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# To link your custom op library to ONNX Runtime, you need to
# compile your C++ code into a shared library and link it to ONNX Runtime.
# Follow the instructions below:
#
# 1. Implement your custom op in C++ by following the
#    `ONNX Runtime instructions <https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__.
# 2. Download the ONNX Runtime source distribution from
#    `ONNX Runtime releases <https://github.com/microsoft/onnxruntime/releases>`__.
# 3. Compile and link your custom op library to ONNX Runtime, for example:
#
#    .. code-block:: bash
#
#       $ g++ -shared -fPIC -o libcustom_op_library.so custom_op_library.cc -I /path/to/downloaded/ort/include/ -L /path/to/downloaded/ort/lib/ -lonnxruntime
#
# 4. Run the model with the ONNX Runtime Python API and compare the results with PyTorch.
#
#    .. code-block:: python
#
#       ort_session_options = onnxruntime.SessionOptions()
#
#       # NOTE: Link the custom op library to ONNX Runtime and replace the path
#       # with the path to your custom op library
#       ort_session_options.register_custom_ops_library(
#           "/path/to/libcustom_op_library.so"
#       )
#       ort_session = onnxruntime.InferenceSession(
#           "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options)
#
#       def to_numpy(tensor):
#           return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
#
#       onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x)
#       onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
#       onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
#
#       torch_outputs = custom_addandround_model(input_addandround_x)
#       torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)
#
#       assert len(torch_outputs) == len(onnxruntime_outputs)
#       for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
#           torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
#
# Conclusion
# ----------
#
# Congratulations! In this tutorial, we explored the :class:`OnnxRegistry` API and
# discovered how to create custom implementations for unsupported or existing ATen operators
# using ONNX Script.
# Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
# providing us with a comprehensive understanding of handling unsupported
# operators in the ONNX ecosystem.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:
#
