"""1Introduction to TorchScript2===========================34**Authors:** James Reed ([email protected]), Michael Suo ([email protected]), rev256This tutorial is an introduction to TorchScript, an intermediate7representation of a PyTorch model (subclass of ``nn.Module``) that8can then be run in a high-performance environment such as C++.910In this tutorial we will cover:11121. The basics of model authoring in PyTorch, including:1314- Modules15- Defining ``forward`` functions16- Composing modules into a hierarchy of modules17182. Specific methods for converting PyTorch modules to TorchScript, our19high-performance deployment runtime2021- Tracing an existing module22- Using scripting to directly compile a module23- How to compose both approaches24- Saving and loading TorchScript modules2526We hope that after you complete this tutorial, you will proceed to go through27`the follow-on tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html>`_28which will walk you through an example of actually calling a TorchScript29model from C++.3031"""3233import torch # This is all you need to use both PyTorch and TorchScript!34print(torch.__version__)35torch.manual_seed(191009) # set the seed for reproducibility363738######################################################################39# Basics of PyTorch Model Authoring40# ---------------------------------41#42# Let’s start out by defining a simple ``Module``. A ``Module`` is the43# basic unit of composition in PyTorch. It contains:44#45# 1. A constructor, which prepares the module for invocation46# 2. A set of ``Parameters`` and sub-\ ``Modules``. These are initialized47# by the constructor and can be used by the module during invocation.48# 3. A ``forward`` function. This is the code that is run when the module49# is invoked.50#51# Let’s examine a small example:52#5354class MyCell(torch.nn.Module):55def __init__(self):56super(MyCell, self).__init__()5758def forward(self, x, h):59new_h = torch.tanh(x + h)60return new_h, new_h6162my_cell = MyCell()63x = torch.rand(3, 4)64h = torch.rand(3, 4)65print(my_cell(x, h))666768######################################################################69# So we’ve:70#71# 1. Created a class that subclasses ``torch.nn.Module``.72# 2. Defined a constructor. The constructor doesn’t do much, just calls73# the constructor for ``super``.74# 3. Defined a ``forward`` function, which takes two inputs and returns75# two outputs. The actual contents of the ``forward`` function are not76# really important, but it’s sort of a fake `RNN77# cell <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__–that78# is–it’s a function that is applied on a loop.79#80# We instantiated the module, and made ``x`` and ``h``, which are just 3x481# matrices of random values. Then we invoked the cell with82# ``my_cell(x, h)``. This in turn calls our ``forward`` function.83#84# Let’s do something a little more interesting:85#8687class MyCell(torch.nn.Module):88def __init__(self):89super(MyCell, self).__init__()90self.linear = torch.nn.Linear(4, 4)9192def forward(self, x, h):93new_h = torch.tanh(self.linear(x) + h)94return new_h, new_h9596my_cell = MyCell()97print(my_cell)98print(my_cell(x, h))99100101######################################################################102# We’ve redefined our module ``MyCell``, but this time we’ve added a103# ``self.linear`` attribute, and we invoke ``self.linear`` in the forward104# function.105#106# What exactly is happening here? 


######################################################################
# Now let’s examine said flexibility:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))


######################################################################
# We’ve once again redefined our ``MyCell`` class, but here we’ve defined
# ``MyDecisionGate``. This module utilizes **control flow**. Control flow
# consists of things like loops and ``if``-statements.
#
# Many frameworks take the approach of computing symbolic derivatives
# given a full program representation. However, in PyTorch, we use a
# gradient tape. We record operations as they occur, and replay them
# backwards in computing derivatives. In this way, the framework does not
# have to explicitly define derivatives for all constructs in the
# language.
#
# .. figure:: https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif
#    :alt: How autograd works
#
#    How autograd works
#


######################################################################
# Basics of TorchScript
# ---------------------
#
# Now let’s take our running example and see how we can apply TorchScript.
#
# In short, TorchScript provides tools to capture the definition of your
# model, even in light of the flexible and dynamic nature of PyTorch.
# Let’s begin by examining what we call **tracing**.
#
# Tracing ``Modules``
# ~~~~~~~~~~~~~~~~~~~
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)


######################################################################
# We’ve rewound a bit and taken the second version of our ``MyCell``
# class. As before, we’ve instantiated it, but this time, we’ve called
# ``torch.jit.trace``, passed in the ``Module``, and passed in *example
# inputs* the network might see.
#
# What exactly has this done? It has invoked the ``Module``, recorded the
# operations that occurred when the ``Module`` was run, and created an
# instance of ``torch.jit.ScriptModule`` (of which ``TracedModule`` is an
# instance).
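

######################################################################
# As a quick sanity check (an aside, not part of the original example),
# we can confirm that tracing really did produce a ``ScriptModule``, and
# that the result still behaves like an ordinary ``Module``:
#

print(isinstance(traced_cell, torch.jit.ScriptModule))  # True: tracing produced a ScriptModule
print(isinstance(traced_cell, torch.nn.Module))         # True: it is a drop-in for the original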


######################################################################
# TorchScript records its definitions in an Intermediate Representation
# (or IR), commonly referred to in deep learning as a *graph*. We can
# examine the graph with the ``.graph`` property:
#

print(traced_cell.graph)


######################################################################
# However, this is a very low-level representation and most of the
# information contained in the graph is not useful for end users. Instead,
# we can use the ``.code`` property to give a Python-syntax interpretation
# of the code:
#

print(traced_cell.code)


######################################################################
# So **why** did we do all this? There are several reasons:
#
# 1. TorchScript code can be invoked in its own interpreter, which is
#    basically a restricted Python interpreter. This interpreter does not
#    acquire the Global Interpreter Lock, and so many requests can be
#    processed on the same instance simultaneously.
# 2. This format allows us to save the whole model to disk and load it
#    into another environment, such as in a server written in a language
#    other than Python.
# 3. TorchScript gives us a representation in which we can do compiler
#    optimizations on the code to provide more efficient execution.
# 4. TorchScript allows us to interface with many backend/device runtimes
#    that require a broader view of the program than individual operators.
#
# We can see that invoking ``traced_cell`` produces the same results as
# the Python module:
#

print(my_cell(x, h))
print(traced_cell(x, h))


######################################################################
# Using Scripting to Convert Modules
# ----------------------------------
#
# There’s a reason we used version two of our module, and not the one with
# the control-flow-laden submodule. Let’s examine that now:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)


######################################################################
# Looking at the ``.code`` output, we can see that the ``if-else`` branch
# is nowhere to be found! Why? Tracing does exactly what we said it would:
# run the code, record the operations *that happen*, and construct a
# ``ScriptModule`` that does exactly that. Unfortunately, things like control
# flow are erased.
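

######################################################################
# To see the consequence concretely (a small sketch, not part of the
# original example), compare the eager gate with its traced counterpart
# on inputs that exercise both branches. The eager gate branches on the
# data; the trace replays whichever branch happened to run at trace
# time, so it gets exactly one of these two inputs wrong:
#

pos, neg = torch.ones(3, 4), -torch.ones(3, 4)
print(my_cell.dg(pos), my_cell.dg(neg))          # eager: both come out as all ones
print(traced_cell.dg(pos), traced_cell.dg(neg))  # traced: one branch was baked in, so one result disagrees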


######################################################################
# How can we faithfully represent this module in TorchScript? We provide a
# **script compiler**, which does direct analysis of your Python source
# code to transform it into TorchScript. Let’s convert ``MyDecisionGate``
# using the script compiler:
#

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)


######################################################################
# Hooray! We’ve now faithfully captured the behavior of our program in
# TorchScript. Let’s now try running the program:
#

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))


######################################################################
# Mixing Scripting and Tracing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some situations call for using tracing rather than scripting (e.g., a
# module has many architectural decisions that are made based on constant
# Python values that we would like not to appear in TorchScript). In this
# case, scripting can be composed with tracing: ``torch.jit.script`` will
# inline the code for a traced module, and tracing will inline the code
# for a scripted module.
#
# An example of the first case:
#

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)


######################################################################
# And an example of the second case:
#

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)


######################################################################
# This way, scripting and tracing can each be used when the situation
# calls for them, and they can be used together.
#
# Saving and Loading models
# -------------------------
#
# We provide APIs to save and load TorchScript modules to/from disk in an
# archive format. This format includes code, parameters, attributes, and
# debug information, meaning that the archive is a freestanding
# representation of the model that can be loaded in an entirely separate
# process. Let’s save and load our wrapped RNN module:
#

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)


######################################################################
# As you can see, serialization preserves the module hierarchy and the
# code we’ve been examining throughout. The model can also be loaded, for
# example, `into
# C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`__ for
# Python-free execution.
#
# Further Reading
# ~~~~~~~~~~~~~~~
#
# We’ve completed our tutorial! For a more involved demonstration, check
# out the NeurIPS demo for converting machine translation models using
# TorchScript:
# https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ
#