"""
Introduction to TorchScript
===========================

**Authors:** James Reed ([email protected]), Michael Suo ([email protected]), rev2

This tutorial is an introduction to TorchScript, an intermediate
representation of a PyTorch model (subclass of ``nn.Module``) that
can then be run in a high-performance environment such as C++.

In this tutorial we will cover:

1. The basics of model authoring in PyTorch, including:

   - Modules
   - Defining ``forward`` functions
   - Composing modules into a hierarchy of modules

2. Specific methods for converting PyTorch modules to TorchScript, our
   high-performance deployment runtime

   - Tracing an existing module
   - Using scripting to directly compile a module
   - How to compose both approaches
   - Saving and loading TorchScript modules

We hope that after you complete this tutorial, you will proceed to go through
`the follow-on tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html>`_,
which will walk you through an example of actually calling a TorchScript
model from C++.
"""

import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
torch.manual_seed(191009)  # set the seed for reproducibility


######################################################################
# Basics of PyTorch Model Authoring
# ---------------------------------
#
# Let’s start out by defining a simple ``Module``. A ``Module`` is the
# basic unit of composition in PyTorch. It contains:
#
# 1. A constructor, which prepares the module for invocation
# 2. A set of ``Parameters`` and sub-\ ``Modules``. These are initialized
#    by the constructor and can be used by the module during invocation.
# 3. A ``forward`` function. This is the code that is run when the module
#    is invoked.
#
# Let’s examine a small example:
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

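######################################################################
# An aside that is not part of the original tutorial: because the cell maps
# an input and a hidden state to a new hidden state, it can be applied step
# by step over a sequence. Here is a minimal sketch, assuming a made-up
# sequence of five time steps with the same batch and feature sizes:
#

seq = torch.rand(5, 3, 4)   # (time, batch, features) -- an assumed layout
state = torch.zeros(3, 4)
for t in range(seq.size(0)):
    _, state = my_cell(seq[t], state)
print(state)

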
######################################################################
# So we’ve:
#
# 1. Created a class that subclasses ``torch.nn.Module``.
# 2. Defined a constructor. The constructor doesn’t do much, just calls
#    the constructor for ``super``.
# 3. Defined a ``forward`` function, which takes two inputs and returns
#    two outputs. The actual contents of the ``forward`` function are not
#    really important, but it’s sort of a fake `RNN
#    cell <https://colah.github.io/posts/2015-08-Understanding-LSTMs/>`__;
#    that is, it’s a function that is applied in a loop.
#
# We instantiated the module, and made ``x`` and ``h``, which are just 3x4
# matrices of random values. Then we invoked the cell with
# ``my_cell(x, h)``. This in turn calls our ``forward`` function.
#
# Let’s do something a little more interesting:
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))

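######################################################################
# An aside that is not in the original tutorial: a quick way to see the
# ``Parameters`` that the ``Linear`` submodule registered is to iterate over
# ``named_parameters()``:
#

for name, param in my_cell.named_parameters():
    print(name, param.shape)

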
######################################################################
# We’ve redefined our module ``MyCell``, but this time we’ve added a
# ``self.linear`` attribute, and we invoke ``self.linear`` in the forward
# function.
#
# What exactly is happening here? ``torch.nn.Linear`` is a ``Module`` from
# the PyTorch standard library. Just like ``MyCell``, it can be invoked
# using the call syntax. We are building a hierarchy of ``Module``\ s.
#
# ``print`` on a ``Module`` will give a visual representation of the
# ``Module``\ ’s subclass hierarchy. In our example, we can see our
# ``Linear`` subclass and its parameters.
#
# By composing ``Module``\ s in this way, we can succinctly and readably
# author models with reusable components.
#
# You may have noticed ``grad_fn`` on the outputs. This is a detail of
# PyTorch’s method of automatic differentiation, called
# `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__.
# In short, this system allows us to compute derivatives through
# potentially complex programs. The design allows for a massive amount of
# flexibility in model authoring.
#
# Now let’s examine said flexibility:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))


######################################################################
# We’ve once again redefined our ``MyCell`` class, but here we’ve defined
# ``MyDecisionGate``. This module utilizes **control flow**. Control flow
# consists of things like loops and ``if``-statements.
#
# Many frameworks take the approach of computing symbolic derivatives
# given a full program representation. However, in PyTorch, we use a
# gradient tape. We record operations as they occur, and replay them
# backwards in computing derivatives. In this way, the framework does not
# have to explicitly define derivatives for all constructs in the
# language.
#
# .. figure:: https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif
#    :alt: How autograd works
#
#    How autograd works
#

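######################################################################
# To make the tape idea concrete, here is a minimal sketch that is not part
# of the original tutorial: run the cell forward, reduce the output to a
# scalar, call ``backward``, and look at the gradient that autograd recorded
# for the ``Linear`` weight.
#

out, _ = my_cell(x, h)
out.sum().backward()
print(my_cell.linear.weight.grad.shape)

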
######################################################################
# Basics of TorchScript
# ---------------------
#
# Now let’s take our running example and see how we can apply TorchScript.
#
# In short, TorchScript provides tools to capture the definition of your
# model, even in light of the flexible and dynamic nature of PyTorch.
# Let’s begin by examining what we call **tracing**.
#
# Tracing ``Modules``
# ~~~~~~~~~~~~~~~~~~~
#

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)


######################################################################
# We’ve rewound a bit and taken the second version of our ``MyCell``
# class. As before, we’ve instantiated it, but this time, we’ve called
# ``torch.jit.trace``, passed in the ``Module``, and passed in *example
# inputs* the network might see.
#
# What exactly has this done? It has invoked the ``Module``, recorded the
# operations that occurred when the ``Module`` was run, and created an
# instance of ``torch.jit.ScriptModule`` (of which ``TracedModule`` is an
# instance).
#
# TorchScript records its definitions in an Intermediate Representation
# (or IR), commonly referred to in deep learning as a *graph*. We can
# examine the graph with the ``.graph`` property:
#

print(traced_cell.graph)


######################################################################
# However, this is a very low-level representation and most of the
# information contained in the graph is not useful for end users. Instead,
# we can use the ``.code`` property to give a Python-syntax interpretation
# of the code:
#

print(traced_cell.code)


######################################################################
# So **why** did we do all this? There are several reasons:
#
# 1. TorchScript code can be invoked in its own interpreter, which is
#    basically a restricted Python interpreter. This interpreter does not
#    acquire the Global Interpreter Lock, and so many requests can be
#    processed on the same instance simultaneously.
# 2. This format allows us to save the whole model to disk and load it
#    into another environment, such as in a server written in a language
#    other than Python.
# 3. TorchScript gives us a representation in which we can do compiler
#    optimizations on the code to provide more efficient execution.
# 4. TorchScript allows us to interface with many backend/device runtimes
#    that require a broader view of the program than individual operators.
#
# We can see that invoking ``traced_cell`` produces the same results as
# the Python module:
#

print(my_cell(x, h))
print(traced_cell(x, h))


######################################################################
# Using Scripting to Convert Modules
# ----------------------------------
#
# There’s a reason we used version two of our module, and not the one with
# the control-flow-laden submodule. Let’s examine that now:
#

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)


######################################################################
# Looking at the ``.code`` output, we can see that the ``if-else`` branch
# is nowhere to be found! Why? Tracing does exactly what we said it would:
# run the code, record the operations *that happen* and construct a
# ``ScriptModule`` that does exactly that. Unfortunately, things like control
# flow are erased.
#
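# To see the consequence directly, here is a small behavioral check that is
# not part of the original tutorial: tracing a standalone function with
# data-dependent control flow bakes in whichever branch the example input
# happened to take (the names below are made up for illustration).
#

def standalone_gate(x):
    if x.sum() > 0:
        return x
    else:
        return -x

# Trace with an input whose sum is negative, so only the ``-x`` branch is recorded.
traced_gate = torch.jit.trace(standalone_gate, (torch.full((3, 4), -1.0),))

pos_input = torch.ones(3, 4)
print(standalone_gate(pos_input))  # eager Python: returns the input unchanged
print(traced_gate(pos_input))      # traced version: still negates, the branch was baked in


######################################################################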
# How can we faithfully represent this module in TorchScript? We provide a
# **script compiler**, which does direct analysis of your Python source
# code to transform it into TorchScript. Let’s convert ``MyDecisionGate``
# using the script compiler:
#

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)


######################################################################
# Hooray! We’ve now faithfully captured the behavior of our program in
# TorchScript. Let’s now try running the program:
#

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))


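######################################################################
# A side note that is not part of the original tutorial: ``torch.jit.script``
# can also be used as a decorator on a free function, compiling its loops and
# branches directly from the Python source. ``scripted_relu_chain`` below is a
# made-up example name.
#

@torch.jit.script
def scripted_relu_chain(x):
    for _ in range(3):
        x = torch.relu(x)
    return x

print(scripted_relu_chain.code)

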
######################################################################
# Mixing Scripting and Tracing
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some situations call for using tracing rather than scripting (e.g. a
# module has many architectural decisions that are made based on constant
# Python values that we would like not to appear in TorchScript). In this
# case, scripting can be composed with tracing: ``torch.jit.script`` will
# inline the code for a traced module, and tracing will inline the code
# for a scripted module.
#
# An example of the first case:
#

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)


######################################################################
# And an example of the second case:
#

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)


######################################################################
# This way, scripting and tracing can each be used when the situation
# calls for them, and they can be used together.
#
# Saving and Loading models
# -------------------------
#
# We provide APIs to save and load TorchScript modules to/from disk in an
# archive format. This format includes code, parameters, attributes, and
# debug information, meaning that the archive is a freestanding
# representation of the model that can be loaded in an entirely separate
# process. Let’s save and load our wrapped RNN module:
#

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)


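######################################################################
# A quick consistency check that is not part of the original tutorial: the
# reloaded module should produce the same values as the in-memory traced
# module for the same input.
#

xs = torch.rand(10, 3, 4)
print(torch.allclose(traced(xs), loaded(xs)))

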
######################################################################
# As you can see, serialization preserves the module hierarchy and the
# code we’ve been examining throughout. The model can also be loaded, for
# example, `into
# C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`__ for
# Python-free execution.
#
# Further Reading
# ~~~~~~~~~~~~~~~
#
# We’ve completed our tutorial! For a more involved demonstration, check
# out the NeurIPS demo for converting machine translation models using
# TorchScript:
# https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ
#