"""1(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime2===================================================================================34.. note::5As of PyTorch 2.1, there are two versions of ONNX Exporter.67* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0.8* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0.910In this tutorial, we describe how to convert a model defined11in PyTorch into the ONNX format using the TorchScript ``torch.onnx.export`` ONNX exporter.1213The exported model will be executed with ONNX Runtime.14ONNX Runtime is a performance-focused engine for ONNX models,15which inferences efficiently across multiple platforms and hardware16(Windows, Linux, and Mac and on both CPUs and GPUs).17ONNX Runtime has proved to considerably increase performance over18multiple models as explained `here19<https://cloudblogs.microsoft.com/opensource/2019/05/22/onnx-runtime-machine-learning-inferencing-0-4-release>`__2021For this tutorial, you will need to install `ONNX <https://github.com/onnx/onnx>`__22and `ONNX Runtime <https://github.com/microsoft/onnxruntime>`__.23You can get binary builds of ONNX and ONNX Runtime with2425.. code-block:: bash2627%%bash28pip install onnx onnxruntime2930ONNX Runtime recommends using the latest stable runtime for PyTorch.3132"""3334# Some standard imports35import numpy as np3637from torch import nn38import torch.utils.model_zoo as model_zoo39import torch.onnx404142######################################################################43# Super-resolution is a way of increasing the resolution of images, videos44# and is widely used in image processing or video editing. 
######################################################################
# Super-resolution is a way of increasing the resolution of images and
# videos, and is widely used in image processing and video editing. For this
# tutorial, we will use a small super-resolution model.
#
# First, let's create a ``SuperResolution`` model in PyTorch.
# This model uses the efficient sub-pixel convolution layer described in
# `"Real-Time Single Image and Video Super-Resolution Using an Efficient
# Sub-Pixel Convolutional Neural Network" - Shi et al <https://arxiv.org/abs/1609.05158>`__
# for increasing the resolution of an image by an upscale factor.
# The model expects the Y component of the ``YCbCr`` representation of an
# image as input, and outputs the upscaled Y component in super resolution.
#
# `The
# model <https://github.com/pytorch/examples/blob/master/super_resolution/model.py>`__
# comes directly from PyTorch's examples without modification:
#

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)


# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)


######################################################################
# Ordinarily, you would now train this model; however, for this tutorial,
# we will instead download some pretrained weights. Note that this model
# was not trained fully for good accuracy and is used here for
# demonstration purposes only.
#
# It is important to call ``torch_model.eval()`` or ``torch_model.train(False)``
# before exporting the model, to turn the model into inference mode.
# This is required since operators like dropout or batchnorm behave
# differently in inference and training mode.
#

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 64    # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# set the model to inference mode
torch_model.eval()

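
######################################################################
# To see why evaluation mode matters, here is a minimal sketch using a
# throwaway ``nn.Dropout`` module (not part of the tutorial's model):
# in training mode dropout randomly zeroes activations, while in eval
# mode it is the identity, so a trace recorded in training mode would
# bake in the wrong behavior.

drop = nn.Dropout(p=0.5)
probe = torch.ones(1, 4)

drop.train()
print(drop(probe))   # some entries zeroed, survivors scaled by 1/(1-p)
drop.eval()
print(drop(probe))   # identical to the input
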
######################################################################
# Exporting a model in PyTorch works via tracing or scripting. This
# tutorial will use as an example a model exported by tracing.
# To export a model, we call the ``torch.onnx.export()`` function.
# This will execute the model, recording a trace of what operators
# are used to compute the outputs.
# Because ``export`` runs the model, we need to provide an input
# tensor ``x``. The values in it can be random as long as it is of the
# right type and size.
# Note that the input size will be fixed in the exported ONNX graph for
# all of the input's dimensions, unless specified as dynamic axes.
# In this example we export the model with a fixed-size sample input,
# but then specify the first dimension as dynamic in the ``dynamic_axes``
# parameter of ``torch.onnx.export()``.
# The exported model will thus accept inputs of size ``[batch_size, 1, 224, 224]``
# where ``batch_size`` can be variable.
#
# To learn more details about PyTorch's export interface, check out the
# `torch.onnx documentation <https://pytorch.org/docs/master/onnx.html>`__.
#

# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},      # variable length axes
                                'output': {0: 'batch_size'}})

######################################################################
# We also computed ``torch_out``, the model's output,
# which we will use to verify that the model we exported computes
# the same values when run in ONNX Runtime.
#
# But before verifying the model's output with ONNX Runtime, we will check
# the ONNX model with the ONNX API.
# First, ``onnx.load("super_resolution.onnx")`` will load the saved model and
# return an ``onnx.ModelProto`` structure (a top-level file/container format
# for bundling an ML model; for more information, see the
# `onnx.proto documentation <https://github.com/onnx/onnx/blob/master/onnx/onnx.proto>`__).
# Then, ``onnx.checker.check_model(onnx_model)`` will verify the model's structure
# and confirm that the model has a valid schema.
# The validity of the ONNX graph is verified by checking the model's
# version, the graph's structure, as well as the nodes and their inputs
# and outputs.
#

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

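
######################################################################
# Optionally (a sketch, not part of the original tutorial), you can also
# render a human-readable summary of the exported graph with
# ``onnx.helper.printable_graph``; the input should show a symbolic
# ``batch_size`` first dimension, confirming the dynamic axis:

print(onnx.helper.printable_graph(onnx_model.graph))
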
######################################################################
# Now let's compute the output using ONNX Runtime's Python APIs.
# This part can normally be done in a separate process or on another
# machine, but we will continue in the same process so that we can
# verify that ONNX Runtime and PyTorch are computing the same value
# for the network.
#
# In order to run the model with ONNX Runtime, we need to create an
# inference session for the model with the chosen configuration
# parameters (here we use the default config).
# Once the session is created, we evaluate the model using the ``run()`` API.
# The output of this call is a list containing the outputs of the model
# computed by ONNX Runtime.
#

import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")


######################################################################
# We should see that the outputs of PyTorch and ONNX Runtime match
# numerically within the given tolerances (``rtol=1e-03`` and ``atol=1e-05``).
# As a side-note, if they do not match then there is an issue in the
# ONNX exporter, so please contact us in that case.
#

######################################################################
# Timing Comparison Between Models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

######################################################################
# Since ONNX Runtime is optimized for inference speed, running the same
# data through the ONNX model instead of the native PyTorch model should
# result in an improvement of up to 2x. The improvement is more
# pronounced with higher batch sizes.


import time

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

start = time.time()
torch_out = torch_model(x)
end = time.time()
print(f"Inference of PyTorch model used {end - start} seconds")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
end = time.time()
print(f"Inference of ONNX model used {end - start} seconds")

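
######################################################################
# A single timed call like the one above can be noisy. As a sketch (not
# part of the original tutorial), a fairer comparison warms up both
# runtimes first and averages over several runs:

n_runs = 10   # hypothetical run count for averaging

# Warm up, then average the PyTorch model over several runs
with torch.no_grad():
    torch_model(x)                        # warm-up run
    start = time.time()
    for _ in range(n_runs):
        torch_model(x)
print(f"PyTorch: {(time.time() - start) / n_runs:.4f} seconds per run")

# Warm up, then average ONNX Runtime over the same number of runs
ort_session.run(None, ort_inputs)         # warm-up run
start = time.time()
for _ in range(n_runs):
    ort_session.run(None, ort_inputs)
print(f"ONNX Runtime: {(time.time() - start) / n_runs:.4f} seconds per run")
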
######################################################################
# Running the model on an image using ONNX Runtime
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#


######################################################################
# So far we have exported a model from PyTorch and shown how to load it
# and run it in ONNX Runtime with a dummy tensor as an input.

######################################################################
# For this tutorial, we will use a widely used cat image, which looks
# like this:
#
# .. figure:: /_static/img/cat_224x224.jpg
#    :alt: cat
#

######################################################################
# First, let's load the image and preprocess it using the standard PIL
# Python library. Note that this preprocessing is standard practice when
# processing data for training/testing neural networks.
#
# We first resize the image to fit the size of the model's input (224x224).
# Then we split the image into its Y, Cb, and Cr components.
# These components represent a grayscale image (Y), and
# the blue-difference (Cb) and red-difference (Cr) chroma components.
# Since the human eye is most sensitive to the Y component, we are
# interested in this component, which we will transform.
# After extracting the Y component, we convert it to a tensor which
# will be the input of our model.
#

from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)


######################################################################
# Now, as a next step, let's take the tensor representing the
# grayscale resized cat image and run the super-resolution model in
# ONNX Runtime as explained previously.
#

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]


######################################################################
# At this point, the output of the model is a tensor.
# Now, we'll process the output of the model to reconstruct the
# final output image from the output tensor, and save the image.
# The post-processing steps are adapted from the PyTorch
# implementation of the super-resolution model
# `here <https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py>`__.
#

img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# Assemble the output image, following the post-processing steps from the
# PyTorch implementation
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Save the output image; we will compare it with the resized original below
final_img.save("./_static/img/cat_superres_with_ort.jpg")

# Save resized original image (without super-resolution)
img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
img.save("./_static/img/cat_resized.jpg")

######################################################################
# Here is the comparison between the two images:
#
# .. figure:: /_static/img/cat_resized.jpg
#
#    Low-resolution image
#
# .. figure:: /_static/img/cat_superres_with_ort.jpg
#
#    Image after super-resolution
#
#
# ONNX Runtime is a cross-platform engine; it runs across
# multiple platforms and on both CPUs and GPUs.
#
# ONNX Runtime can also be deployed to the cloud for model inferencing
# using Azure Machine Learning Services. More information is available
# `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
#
# More information about ONNX Runtime's performance is available
# `here <https://onnxruntime.ai/docs/performance>`__.
#
#
# For more information about ONNX Runtime, see its
# `GitHub repository <https://github.com/microsoft/onnxruntime>`__.
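
######################################################################
# As a closing usage sketch (not part of the original tutorial), the
# preprocessing, inference, and post-processing steps above can be
# wrapped into a single helper; the name ``super_resolve`` is
# hypothetical:

def super_resolve(image_path, session):
    """Run the exported super-resolution model on an image file."""
    img = Image.open(image_path)
    img = transforms.Resize([224, 224])(img)
    y, cb, cr = img.convert('YCbCr').split()
    y_tensor = transforms.ToTensor()(y).unsqueeze(0)
    # Run the Y channel through the ONNX Runtime session
    out = session.run(None, {session.get_inputs()[0].name: to_numpy(y_tensor)})[0]
    out_y = Image.fromarray(np.uint8((out[0] * 255.0).clip(0, 255)[0]), mode='L')
    # Upscale the chroma channels and merge back to RGB
    return Image.merge("YCbCr", [
        out_y,
        cb.resize(out_y.size, Image.BICUBIC),
        cr.resize(out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Example usage:
# super_resolve("./_static/img/cat.jpg", ort_session).save("cat_sr.jpg")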