CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!
Path: blob/main/prototype_source/torchscript_freezing.py
"""1Model Freezing in TorchScript2=============================34In this tutorial, we introduce the syntax for *model freezing* in TorchScript.5Freezing is the process of inlining Pytorch module parameters and attributes6values into the TorchScript internal representation. Parameter and attribute7values are treated as final values and they cannot be modified in the resulting8Frozen module.910Basic Syntax11------------12Model freezing can be invoked using API below:1314``torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule``1516Note the input module can either be the result of scripting or tracing.17See https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html1819Next, we demonstrate how freezing works using an example:20"""2122import torch, time2324class Net(torch.nn.Module):25def __init__(self):26super(Net, self).__init__()27self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)28self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)29self.dropout1 = torch.nn.Dropout2d(0.25)30self.dropout2 = torch.nn.Dropout2d(0.5)31self.fc1 = torch.nn.Linear(9216, 128)32self.fc2 = torch.nn.Linear(128, 10)3334def forward(self, x):35x = self.conv1(x)36x = torch.nn.functional.relu(x)37x = self.conv2(x)38x = torch.nn.functional.max_pool2d(x, 2)39x = self.dropout1(x)40x = torch.flatten(x, 1)41x = self.fc1(x)42x = torch.nn.functional.relu(x)43x = self.dropout2(x)44x = self.fc2(x)45output = torch.nn.functional.log_softmax(x, dim=1)46return output4748@torch.jit.export49def version(self):50return 1.05152net = torch.jit.script(Net())53fnet = torch.jit.freeze(net)5455print(net.conv1.weight.size())56print(net.conv1.bias)5758try:59print(fnet.conv1.bias)60# without exception handling, prints:61# RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field62# with name 'conv1'63except RuntimeError:64print("field 'conv1' is inlined. 
It does not exist in 'fnet'")6566try:67fnet.version()68# without exception handling, prints:69# RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field70# with name 'version'71except RuntimeError:72print("method 'version' is not deleted in fnet. Only 'forward' is preserved")7374fnet2 = torch.jit.freeze(net, ["version"])7576print(fnet2.version())7778B=179warmup = 180iter = 100081input = torch.rand(B, 1,28, 28)8283start = time.time()84for i in range(warmup):85net(input)86end = time.time()87print("Scripted - Warm up time: {0:7.4f}".format(end-start), flush=True)8889start = time.time()90for i in range(warmup):91fnet(input)92end = time.time()93print("Frozen - Warm up time: {0:7.4f}".format(end-start), flush=True)9495start = time.time()96for i in range(iter):97input = torch.rand(B, 1,28, 28)98net(input)99end = time.time()100print("Scripted - Inference: {0:5.2f}".format(end-start), flush=True)101102start = time.time()103for i in range(iter):104input = torch.rand(B, 1,28, 28)105fnet2(input)106end = time.time()107print("Frozen - Inference time: {0:5.2f}".format(end-start), flush =True)108109###############################################################110# On my machine, I measured the time:111#112# * Scripted - Warm up time: 0.0107113# * Frozen - Warm up time: 0.0048114# * Scripted - Inference: 1.35115# * Frozen - Inference time: 1.17116117###############################################################118# In our example, warm up time measures the first two runs. The frozen model119# is 50% faster than the scripted model. On some more complex models, we120# observed even higher speed up of warm up time. 
freezing achieves this speed up121# because it is doing some the work TorchScript has to do when the first couple122# runs are initiated.123#124# Inference time measures inference execution time after the model is warmed up.125# Although we observed significant variation in execution time, the126# frozen model is often about 15% faster than the scripted model. When input is larger,127# we observe a smaller speed up because the execution is dominated by tensor operations.128129###############################################################130# Conclusion131# -----------132# In this tutorial, we learned about model freezing. Freezing is a useful technique to133# optimize models for inference and it also can significantly reduce TorchScript warmup time.134135136