"""
Model Freezing in TorchScript
=============================

In this tutorial, we introduce the syntax for *model freezing* in TorchScript.
Freezing is the process of inlining PyTorch module parameter and attribute
values into the TorchScript internal representation. Parameter and attribute
values are treated as final values, and they cannot be modified in the
resulting frozen module.

Basic Syntax
------------
Model freezing can be invoked using the API below:

``torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule``

Note that the input module can be the result of either scripting or tracing.
See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

Next, we demonstrate how freezing works using an example:
"""

import time

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
net.eval()  # freezing requires the module to be in eval mode
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)
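
###############################################################
# As a quick sanity check (a minimal sketch, not part of the original
# example), printing the frozen module's TorchScript code shows that the
# parameters no longer appear as ``self.*`` attributes; they have typically
# been inlined into the graph as constants:

print(fnet.code)  # frozen forward; weights/biases appear as baked-in constants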

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is not preserved in 'fnet'. Only 'forward' is preserved")
75
fnet2 = torch.jit.freeze(net, ["version"])
76
77
print(fnet2.version())
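
###############################################################
# As noted in the docstring above, freezing also accepts traced modules. A
# minimal sketch (``traced_net`` and ``frozen_traced`` are illustrative names
# and are not used elsewhere in this tutorial):

traced_net = torch.jit.trace(Net().eval(), torch.rand(1, 1, 28, 28))
frozen_traced = torch.jit.freeze(traced_net)  # parameters inlined, as before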

B = 1
warmup = 1
iters = 1000
x = torch.rand(B, 1, 28, 28)

start = time.time()
for i in range(warmup):
    net(x)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(x)
end = time.time()
print("Frozen - Warm up time: {0:7.4f}".format(end - start), flush=True)

start = time.time()
for i in range(iters):
    x = torch.rand(B, 1, 28, 28)
    net(x)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end - start), flush=True)

start = time.time()
for i in range(iters):
    x = torch.rand(B, 1, 28, 28)
    fnet2(x)
end = time.time()
print("Frozen - Inference time: {0:5.2f}".format(end - start), flush=True)

###############################################################
# On my machine, I measured the time:
#
# * Scripted - Warm up time: 0.0107
# * Frozen - Warm up time: 0.0048
# * Scripted - Inference: 1.35
# * Frozen - Inference time: 1.17

###############################################################
# In our example, warm up time measures the first run of each model (``warmup``
# is 1). The frozen model warms up in roughly half the time of the scripted
# model. On some more complex models, we observed even larger warm up speedups.
# Freezing achieves this speedup because it performs ahead of time some of the
# work that TorchScript would otherwise do during the first few runs.
#
# Inference time measures inference execution time after the model is warmed
# up. Although we observed significant variation in execution time, the frozen
# model is often about 15% faster than the scripted model. When the input is
# larger, we observe a smaller speedup because the execution is dominated by
# tensor operations.
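
###############################################################
# To see the larger-input effect described above, one can rerun the timing
# loops with a bigger batch (a hypothetical variation; the batch size of 64
# and iteration count of 100 are arbitrary choices, not original measurements):

big_input = torch.rand(64, 1, 28, 28)

start = time.time()
for i in range(100):
    net(big_input)
print("Scripted - big batch: {0:5.2f}".format(time.time() - start), flush=True)

start = time.time()
for i in range(100):
    fnet2(big_input)
print("Frozen - big batch: {0:5.2f}".format(time.time() - start), flush=True)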

###############################################################
# Conclusion
# ----------
# In this tutorial, we learned about model freezing. Freezing is a useful
# technique for optimizing models for inference, and it can also significantly
# reduce TorchScript warm up time.