GitHub Repository: pytorch/tutorials
Path: blob/main/advanced_source/super_resolution_with_onnxruntime.py
"""
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
===================================================================================

.. note::
    As of PyTorch 2.1, there are two versions of the ONNX exporter.

    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter, based on the TorchDynamo technology released with PyTorch 2.0.
    * ``torch.onnx.export`` is based on the TorchScript backend and has been available since PyTorch 1.2.0.

In this tutorial, we describe how to convert a model defined
in PyTorch into the ONNX format using the TorchScript-based ``torch.onnx.export`` ONNX exporter.

The exported model will be executed with ONNX Runtime.
ONNX Runtime is a performance-focused engine for ONNX models,
which runs inference efficiently across multiple platforms and hardware
(Windows, Linux, and Mac, on both CPUs and GPUs).
ONNX Runtime has been shown to considerably increase performance for
multiple models, as explained `here
<https://cloudblogs.microsoft.com/opensource/2019/05/22/onnx-runtime-machine-learning-inferencing-0-4-release>`__.

For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__
and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.
You can get binary builds of ONNX and ONNX Runtime with:

.. code-block:: bash

   %%bash
   pip install onnx onnxruntime

ONNX Runtime recommends using the latest stable runtime for PyTorch.

"""

# Some standard imports
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

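######################################################################
# The exporter behavior described in the note above depends on the
# installed PyTorch version, and ONNX Runtime recommends its latest stable
# release, so it can help to record which versions are installed. The
# following is a minimal sketch using only the standard library's
# ``importlib.metadata``; ``installed_versions`` is a hypothetical helper,
# not part of PyTorch, ONNX, or ONNX Runtime.

```python
from importlib import metadata


def installed_versions(packages=("torch", "onnx", "onnxruntime")):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is not installed in this environment
            versions[name] = None
    return versions


print(installed_versions())
```

# A ``None`` entry means the corresponding ``pip install`` step above has
# not been run in the current environment.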
######################################################################
# Super-resolution is a way of increasing the resolution of images and videos,
# and is widely used in image processing and video editing. For this
# tutorial, we will use a small super-resolution model.
#
# First, let's create a ``SuperResolutionNet`` model in PyTorch.
# This model uses the efficient sub-pixel convolution layer described in
# `"Real-Time Single Image and Video Super-Resolution Using an Efficient
# Sub-Pixel Convolutional Neural Network" - Shi et al <https://arxiv.org/abs/1609.05158>`__
# for increasing the resolution of an image by an upscale factor.
# The model expects the Y component of an image in ``YCbCr`` format as input
# and outputs the upscaled Y component.
#
# `The
# model <https://github.com/pytorch/examples/blob/master/super_resolution/model.py>`__
# comes directly from PyTorch's examples without modification:
#

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

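######################################################################
# The sub-pixel convolution works by having ``conv4`` produce
# ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle`` then
# rearranges into a single channel that is ``upscale_factor`` times taller
# and wider. The NumPy sketch below reproduces that rearrangement so the
# index mapping is visible; ``pixel_shuffle`` here is a hypothetical
# re-implementation for illustration, not the PyTorch source.

```python
import numpy as np


def pixel_shuffle(x, r):
    """Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r), the layout change
    performed by nn.PixelShuffle(r)."""
    n, c_rr, h, w = x.shape
    c = c_rr // (r * r)
    # Split the channel axis into (C, r, r): channel c*r*r + i*r + j -> (c, i, j)
    x = x.reshape(n, c, r, r, h, w)
    # Interleave the two r-axes with the spatial axes: (N, C, H, r, W, r)
    x = x.transpose(0, 1, 4, 2, 5, 3)
    # Merge (H, r) -> H*r and (W, r) -> W*r
    return x.reshape(n, c, h * r, w * r)


# conv4 of SuperResolutionNet(upscale_factor=3) emits 3 ** 2 = 9 channels
features = np.arange(1 * 9 * 2 * 2).reshape(1, 9, 2, 2)
upscaled = pixel_shuffle(features, 3)
print(upscaled.shape)  # (1, 1, 6, 6)
```

# Because every convolution in the model preserves the spatial size, a
# 224x224 input yields nine 224x224 maps from ``conv4``, which this
# rearrangement turns into one 672x672 image.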
######################################################################
# Ordinarily, you would now train this model; however, for this tutorial,
# we will instead download some pretrained weights. Note that this model
# was not trained fully for good accuracy and is used here for
# demonstration purposes only.
#
# It is important to call ``torch_model.eval()`` or ``torch_model.train(False)``
# before exporting the model, to switch the model to inference mode.
# This is required because operators like dropout or batchnorm behave
# differently in inference and training mode.
#

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 64    # an arbitrary batch size for the example input

# Initialize model with the pretrained weights
# (on a CPU-only machine, keep every stored tensor on the CPU)
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()

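######################################################################
# To see why the mode matters, the sketch below uses a standalone
# ``nn.Dropout`` layer. ``SuperResolutionNet`` itself contains no dropout
# or batch-norm layers, so its outputs are the same in either mode, but
# calling ``eval()`` before export is a habit worth keeping for models
# that do.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
ones = torch.ones(1000)

dropout.train()            # training mode: randomly zero elements
y_train = dropout(ones)    # survivors are scaled by 1 / (1 - p) = 2

dropout.eval()             # inference mode: dropout is the identity
y_eval = dropout(ones)

print(y_train.unique())
print(torch.equal(y_eval, ones))  # True
print(dropout.training)           # False
```

# If a model were exported while still in training mode, this random
# masking would be baked into the exported graph instead of the identity.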
######################################################################
# Exporting a model in PyTorch works via tracing or scripting. This
# tutorial uses a model exported by tracing.
# To export a model, we call the ``torch.onnx.export()`` function.
# This will execute the model, recording a trace of what operators
# are used to compute the outputs.
# Because ``export`` runs the model, we need to provide an input
# tensor ``x``. The values in this tensor can be random as long as it has the
# right type and size.
# Note that the input size will be fixed in the exported ONNX graph for
# all of the input's dimensions, unless they are specified as dynamic axes.
# In this example, we export the model with an input of batch size ``batch_size``
# (64, set above), but then specify the first dimension as dynamic in the
# ``dynamic_axes`` parameter of ``torch.onnx.export()``.
# The exported model will thus accept inputs of size [batch_size, 1, 224, 224],
# where batch_size can vary.
#
# To learn more details about PyTorch's export interface, check out the
# `torch.onnx documentation <https://pytorch.org/docs/master/onnx.html>`__.
#

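######################################################################
# Tracing records only the operations that actually run for the example
# input, so Python control flow that depends on tensor values is frozen
# into the graph. ``SuperResolutionNet`` has no such branches, which is
# why tracing is safe here. The sketch below illustrates the pitfall with
# ``torch.jit.trace`` on a hypothetical function ``f``; we assume the
# TorchScript-based ``torch.onnx.export`` behaves the same way, since it
# also traces the model.

```python
import warnings

import torch


def f(x):
    # Data-dependent control flow: the branch depends on tensor values
    if x.sum() > 0:
        return x * 2
    return -x


with warnings.catch_warnings():
    # Tracing warns about converting a tensor to a Python bool; silence it here
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(f, torch.ones(3))

print(f(-torch.ones(3)))       # eager: takes the "return -x" branch
print(traced(-torch.ones(3)))  # traced: replays the recorded "x * 2" branch
```

# Models with genuine data-dependent branches are better exported via
# scripting, as mentioned above.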
# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX opset version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

######################################################################
# We also computed ``torch_out``, the output of the model,
# which we will use to verify that the model we exported computes
# the same values when run in ONNX Runtime.
#
# But before verifying the model's output with ONNX Runtime, we will check
# the ONNX model with the ONNX API.
# First, ``onnx.load("super_resolution.onnx")`` will load the saved model and
# return an ``onnx.ModelProto`` structure (a top-level file/container format for
# bundling an ML model; for more information, see the
# `onnx.proto documentation <https://github.com/onnx/onnx/blob/master/onnx/onnx.proto>`__).
# Then, ``onnx.checker.check_model(onnx_model)`` will verify the model's structure
# and confirm that the model has a valid schema.
# The validity of the ONNX graph is verified by checking the model's
# version, the graph's structure, as well as the nodes and their inputs
# and outputs.
#

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

######################################################################
# Now let's compute the output using ONNX Runtime's Python APIs.
# This part can normally be done in a separate process or on another
# machine, but we will continue in the same process so that we can
# verify that ONNX Runtime and PyTorch are computing the same value
# for the network.
#
# In order to run the model with ONNX Runtime, we need to create an
# inference session for the model with the chosen configuration
# parameters (here we use the default config with the CPU execution provider).
# Once the session is created, we evaluate the model using the ``run()`` API.
# The output of this call is a list containing the outputs of the model
# computed by ONNX Runtime.
#

import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

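######################################################################
# ONNX Runtime can run on both CPUs and GPUs, but the session above pins
# ``CPUExecutionProvider``. One way to prefer a GPU when the installed
# build supports it is to filter a preference list against the providers
# the build reports. ``choose_providers`` below is a hypothetical helper,
# written as a pure function so it can be tried without a GPU;
# ``CUDAExecutionProvider`` is only present in GPU-enabled ONNX Runtime builds.

```python
def choose_providers(available, preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Keep the preferred execution providers that are actually available,
    in preference order, falling back to the CPU provider."""
    chosen = [p for p in preferred if p in available]
    return chosen or ["CPUExecutionProvider"]


print(choose_providers(["CPUExecutionProvider"]))  # ['CPUExecutionProvider']
```

# With ONNX Runtime installed, pass ``onnxruntime.get_available_providers()``
# as ``available`` and use the result as the ``providers`` argument of
# ``onnxruntime.InferenceSession``.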
######################################################################
# We should see that the outputs of the PyTorch and ONNX Runtime runs match
# numerically with the given precision (``rtol=1e-03`` and ``atol=1e-05``).
# As a side note, if they do not match, there is likely an issue in the
# ONNX exporter, so please contact us in that case.
#

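######################################################################
# ``np.testing.assert_allclose`` passes only if every element satisfies
# ``|actual - desired| <= atol + rtol * |desired|``, so ``rtol`` dominates
# for large values and ``atol`` for values near zero. When the check fails,
# it can help to locate the worst element. The sketch below, a hypothetical
# helper ``allclose_report``, applies the same elementwise criterion
# (ignoring ``assert_allclose``'s NaN handling) and reports the worst offender.

```python
import numpy as np


def allclose_report(actual, desired, rtol=1e-03, atol=1e-05):
    """Return (passed, index_of_worst_element, excess_over_tolerance) using
    the criterion |actual - desired| <= atol + rtol * |desired|."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    # Positive excess means the element is outside the allowed tolerance
    excess = np.abs(actual - desired) - (atol + rtol * np.abs(desired))
    worst = np.unravel_index(np.argmax(excess), excess.shape)
    return bool(np.all(excess <= 0)), worst, float(excess[worst])


desired = np.array([1.0, 100.0])
print(allclose_report(desired + [5e-4, 5e-2], desired)[0])  # True
print(allclose_report(desired + [2e-3, 0.0], desired)[0])   # False
```

# In this tutorial one would call, for example,
# ``allclose_report(to_numpy(torch_out), ort_outs[0])``.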
######################################################################
# Timing Comparison Between Models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

######################################################################
# Since ONNX Runtime optimizes ONNX models for inference speed, running the
# same data through the ONNX model instead of the native PyTorch model can
# result in an improvement of up to 2x. The improvement is more pronounced
# with larger batch sizes.

import time

x = torch.randn(batch_size, 1, 224, 224)

start = time.perf_counter()
with torch.no_grad():  # skip autograd bookkeeping, which ONNX Runtime does not perform
    torch_out = torch_model(x)
end = time.perf_counter()
print(f"Inference of PyTorch model used {end - start} seconds")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.perf_counter()
ort_outs = ort_session.run(None, ort_inputs)
end = time.perf_counter()
print(f"Inference of ONNX model used {end - start} seconds")

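######################################################################
# A single timed call includes one-off costs such as first-call
# initialization and memory allocation, and it is sensitive to noise from
# other processes. A common remedy is to run a few untimed warm-up calls
# and report the median of several timed runs. ``benchmark`` below is a
# hypothetical stdlib-only helper sketching that approach for CPU
# execution, as used in this tutorial.

```python
import statistics
import time


def benchmark(fn, warmup=3, repeats=10):
    """Call fn() `warmup` times untimed, then `repeats` times timed with
    time.perf_counter(); return the median wall-clock time in seconds."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


print(f"median: {benchmark(lambda: sum(range(10_000))):.6f} s")
```

# Here one could compare, for example,
# ``benchmark(lambda: ort_session.run(None, ort_inputs))`` against a PyTorch
# call wrapped in ``torch.no_grad()``. The defaults ``warmup=3`` and
# ``repeats=10`` are arbitrary choices.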
######################################################################
# Running the model on an image using ONNX Runtime
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

######################################################################
# So far we have exported a model from PyTorch and shown how to load it
# and run it in ONNX Runtime with a dummy tensor as an input.

######################################################################
# For this tutorial, we will use a well-known and widely used cat image,
# shown below:
#
# .. figure:: /_static/img/cat_224x224.jpg
#    :alt: cat
#

######################################################################
# First, let's load the image and preprocess it using the standard PIL
# Python library. Note that this preprocessing is standard practice when
# processing data for training/testing neural networks.
#
# We first resize the image to fit the size of the model's input (224x224).
# Then we split the image into its Y, Cb, and Cr components.
# These components represent a grayscale image (Y), and
# the blue-difference (Cb) and red-difference (Cr) chroma components.
# Because the human eye is more sensitive to the Y (luminance) component,
# this is the component we will transform with the model.
# After extracting the Y component, we convert it to a tensor, which
# will be the input of our model.
#

from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)  # float tensor of shape (1, 224, 224) with values in [0, 1]
img_y.unsqueeze_(0)       # add the batch dimension: (1, 1, 224, 224)

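######################################################################
# ``img.convert('YCbCr')`` performs a per-pixel linear color transform.
# The NumPy sketch below uses the full-range ITU-R BT.601 (JPEG)
# coefficients, which we assume match Pillow's ``YCbCr`` mode up to
# integer rounding, and then mimics ``ToTensor()`` followed by
# ``unsqueeze_(0)`` for a single channel. Both ``rgb_to_ycbcr`` and
# ``y_to_model_input`` are hypothetical helpers for illustration, not part
# of Pillow or torchvision.

```python
import numpy as np


def rgb_to_ycbcr(rgb):
    """Full-range BT.601 RGB -> YCbCr for an array of shape (H, W, 3) with
    values in [0, 255]. Returns float64 values clipped to [0, 255]."""
    rgb = np.asarray(rgb, dtype=np.float64)
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    cb = 128.0 - 0.168736 * r - 0.331264 * g + 0.5 * b
    cr = 128.0 + 0.5 * r - 0.418688 * g - 0.081312 * b
    return np.clip(np.stack([y, cb, cr], axis=-1), 0.0, 255.0)


def y_to_model_input(y):
    """Scale one (H, W) channel to [0, 1] and add channel and batch axes,
    giving a float32 array of shape (1, 1, H, W)."""
    return (np.asarray(y, dtype=np.float32) / 255.0)[np.newaxis, np.newaxis]


pixels = np.array([[[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)  # white, black
ycbcr = rgb_to_ycbcr(pixels)
print(ycbcr.round(3))
print(y_to_model_input(ycbcr[..., 0]).shape)  # (1, 1, 1, 2)
```

# Neutral colors such as white and black carry no chroma, so their Cb and
# Cr values sit at the midpoint 128 and all of their information is in Y.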
######################################################################
# Next, let's take the tensor representing the
# grayscale resized cat image and run the super-resolution model in
# ONNX Runtime as explained previously.
#

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

######################################################################
# At this point, the output of the model is a NumPy array of shape
# (1, 1, 672, 672).
# Now, we'll process the output of the model to reconstruct the
# final output image from the output array, and save the image.
# The post-processing steps have been adapted from the PyTorch
# implementation of the super-resolution model
# `here <https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py>`__.
#

img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# Get the output image by following the post-processing step from the PyTorch
# implementation: upscale the Cb and Cr channels bicubically and merge them
# with the super-resolved Y channel
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Save the super-resolved image; we will compare it with the resized original image
final_img.save("./_static/img/cat_superres_with_ort.jpg")

# Save resized original image (without super-resolution)
img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
img.save("cat_resized.jpg")

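######################################################################
# The ``Image.fromarray`` line above turns the model's float output back
# into 8-bit pixels. The hypothetical helper below isolates that step in
# NumPy: it indexes away the batch and channel axes, scales by 255, clips
# to [0, 255], and casts to ``uint8``, which truncates rather than rounds,
# exactly as the original expression does.

```python
import numpy as np


def y_output_to_uint8(model_output):
    """Turn a (1, 1, H, W) float output in roughly [0, 1] into an (H, W)
    uint8 array: scale by 255, clip to [0, 255], then cast (truncating)."""
    arr = np.asarray(model_output)
    return np.uint8((arr[0] * 255.0).clip(0, 255)[0])


out = np.array([[[[0.0, 0.5], [1.2, -0.1]]]], dtype=np.float32)
print(y_output_to_uint8(out).tolist())  # [[0, 127], [255, 0]]
```

# Applied to this tutorial's ``ort_outs[0]`` of shape (1, 1, 672, 672), the
# helper yields a 672x672 array; out-of-range predictions such as 1.2 or
# -0.1 are clipped rather than wrapped around.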
######################################################################
# Here is the comparison between the two images:
#
# .. figure:: /_static/img/cat_resized.jpg
#
#    Low-resolution image
#
# .. figure:: /_static/img/cat_superres_with_ort.jpg
#
#    Image after super-resolution
#
#
# Because ONNX Runtime is a cross-platform engine, you can run it across
# multiple platforms and on both CPUs and GPUs.
#
# ONNX Runtime can also be deployed to the cloud for model inferencing
# using Azure Machine Learning Services. More information is available `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
#
# More information about ONNX Runtime's performance is available `here <https://onnxruntime.ai/docs/performance>`__.
#
#
# For more information about ONNX Runtime, see `here <https://github.com/microsoft/onnxruntime>`__.
#