"""
Dynamic Quantization
====================

In this recipe you will see how to take advantage of Dynamic
Quantization to accelerate inference on an LSTM-style recurrent neural
network. This reduces the size of the model weights and speeds up model
execution.

Introduction
-------------

There are a number of trade-offs that can be made when designing neural
networks. During model development and training you can alter the
number of layers and number of parameters in a recurrent neural network
and trade off accuracy against model size and/or model latency or
throughput. Such changes can take a lot of time and compute resources
because you are iterating over the model training. Quantization gives
you a way to make a similar trade-off between performance and model
accuracy with a known model after training is completed.

You can give it a try in a single session and you will certainly reduce
your model size significantly and may get a significant latency
reduction without losing a lot of accuracy.

What is dynamic quantization?
-----------------------------

Quantizing a network means converting it to use a reduced precision
integer representation for the weights and/or activations. This saves on
model size and allows the use of higher throughput math operations on
your CPU or GPU.

When converting from floating point to integer values you are
essentially multiplying the floating point value by some scale factor
and rounding the result to a whole number. The various quantization
approaches differ in how they determine that scale factor.

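As a simplified illustration (the values and the scale factor below are
made up, and real quantization schemes also deal with zero points and
saturation):

.. code-block:: python

    x_fp32 = 0.4863                    # original floating point value
    scale = 127.0 / 0.5                # hypothetical scale for a +/-0.5 range
    x_int8 = round(x_fp32 * scale)     # 124, stored as an 8-bit integer
    x_restored = x_int8 / scale        # ~0.488, the value the math effectively uses
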
The key idea with dynamic quantization as described here is that we are
going to determine the scale factor for activations dynamically based on
the data range observed at runtime. This ensures that the scale factor
is "tuned" so that as much signal as possible about each observed
dataset is preserved.

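A rough sketch of that idea looks like the following, with the scale chosen
from the range actually observed in one batch of activations (the tensor
name is a placeholder, and PyTorch's internal observers are more
sophisticated than this one-liner):

.. code-block:: python

    activations = some_layer_output              # float tensor seen at runtime
    scale = activations.abs().max() / 127.0      # derive scale from observed range
    quantized = (activations / scale).round().clamp(-128, 127)
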
The model parameters, on the other hand, are known during model conversion,
and they are converted ahead of time and stored in INT8 form.

Arithmetic in the quantized model is done using vectorized INT8
instructions. Accumulation is typically done with INT16 or INT32 to
avoid overflow. This higher precision value is scaled back down to INT8 if
the next layer is quantized, or converted to FP32 for output.

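A quick back-of-the-envelope check shows why the wider accumulator matters:

.. code-block:: python

    # even one product of two INT8 values can leave the INT8 range,
    # and a long dot product overflows INT16 as well
    print(127 * 127)         # 16129, far outside [-128, 127]
    print(512 * 127 * 127)   # 8258048 for a 512-element dot product, needs INT32
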
Dynamic quantization is relatively free of tuning parameters, which makes
it well suited to be added into production pipelines as a standard part
of converting LSTM models for deployment.



.. note::
   Limitations on the approach taken here


   This recipe provides a quick introduction to the dynamic quantization
   features in PyTorch and the workflow for using it. Our focus is on
   explaining the specific functions used to convert the model. We will
   make a number of significant simplifications in the interest of brevity
   and clarity:


   1. You will start with a minimal LSTM network
   2. You are simply going to initialize the network with a random hidden
      state
   3. You are going to test the network with random inputs
   4. You are not going to train the network in this tutorial
   5. You will see that the quantized form of this network is smaller and
      runs faster than the floating point network we started with
   6. You will see that the output values are generally in the same
      ballpark as the output of the FP32 network, but we are not
      demonstrating here the expected accuracy loss on a real trained
      network

You will see how dynamic quantization is done and be able to see
suggestive reductions in memory use and latency times. Providing a
demonstration that the technique can preserve high levels of model
accuracy on a trained LSTM is left to a more advanced tutorial. If you
want to move right away to that more rigorous treatment please proceed
to the `advanced dynamic quantization
tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`__.

Steps
-------------

This recipe has 5 steps.

1. Set Up - Here you define a very simple LSTM, import modules, and establish
   some random input tensors.

2. Do the Quantization - Here you instantiate a floating point model and then create a quantized
   version of it.

3. Look at Model Size - Here you show that the model size gets smaller.

4. Look at Latency - Here you run the two models and compare model runtime (latency).

5. Look at Accuracy - Here you run the two models and compare outputs.


1: Set Up
~~~~~~~~~~~~~~~
This is a straightforward bit of code to set up for the rest of the
recipe.

The unique module we are importing here is ``torch.quantization``, which
includes PyTorch's quantized operators and conversion functions. We also
define a very simple LSTM model and set up some inputs.

"""

# import the modules used here in this recipe
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time

# define a very, very simple LSTM for demonstration purposes
# in this case, we are wrapping ``nn.LSTM``, one layer, no preprocessing or postprocessing
# inspired by
# `Sequence Models and Long Short-Term Memory Networks tutorial <https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html>`_, by Robert Guthrie
# and `Dynamic Quantization tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`__.
class lstm_for_demonstration(nn.Module):
    """Elementary Long Short-Term Memory style model which simply wraps ``nn.LSTM``
    Not to be used for anything other than demonstration.
    """
    def __init__(self, in_dim, out_dim, depth):
        super(lstm_for_demonstration, self).__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, depth)

    def forward(self, inputs, hidden):
        out, hidden = self.lstm(inputs, hidden)
        return out, hidden

torch.manual_seed(29592)  # set the seed for reproducibility

# shape parameters
model_dimension = 8
sequence_length = 20
batch_size = 1
lstm_depth = 1

# random data for input
inputs = torch.randn(sequence_length, batch_size, model_dimension)
# hidden is actually a tuple of the initial hidden state and the initial cell state
hidden = (torch.randn(lstm_depth, batch_size, model_dimension), torch.randn(lstm_depth, batch_size, model_dimension))

######################################################################
# 2: Do the Quantization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Now we get to the fun part. First we create an instance of the model
# called ``float_lstm``, then we are going to quantize it. We're going to use
# the `torch.quantization.quantize_dynamic <https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic>`__ function, which takes the model,
# then a set of the submodule types we want to have quantized if they
# appear, then the datatype we are targeting. This
# function returns a quantized version of the original model as a new
# module.
#
# That's all it takes.
#

# here is our floating point instance
float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)

# this is the call that does the work
quantized_lstm = torch.quantization.quantize_dynamic(
    float_lstm, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

# show the changes that were made
print('Here is the floating point version of this module:')
print(float_lstm)
print('')
print('and now the quantized version:')
print(quantized_lstm)

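# a quick sanity check: the wrapped ``nn.LSTM`` submodule has been swapped
# for a dynamically quantized counterpart, while the quantized model is
# still called exactly the same way as before
print(type(float_lstm.lstm))
print(type(quantized_lstm.lstm))
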
######################################################################
# 3. Look at Model Size
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We've quantized the model. What does that get us? Well the first
# benefit is that we've replaced the FP32 model parameters with INT8
# values (and some recorded scale factors). This means about 75% less data
# to store and move around. With the default values the reduction shown
# below will be less than 75%, but if you increase the model size above
# (for example, by setting the model dimension to something like 80) the ratio
# will converge towards 4x smaller, as the stored model size is dominated more
# and more by the parameter values.
#

def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p")
    print("model: ", label, ' \t', 'Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

# compare the sizes
f = print_size_of_model(float_lstm, "fp32")
q = print_size_of_model(quantized_lstm, "int8")
print("{0:.2f} times smaller".format(f/q))

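# as suggested above, the compression ratio approaches the full 4x as the
# parameters dominate the saved file; this optional re-run with a larger,
# arbitrarily chosen model dimension of 80 illustrates the effect
wide_float_lstm = lstm_for_demonstration(80, 80, lstm_depth)
wide_quantized_lstm = torch.quantization.quantize_dynamic(
    wide_float_lstm, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
f_wide = print_size_of_model(wide_float_lstm, "fp32 (dim=80)")
q_wide = print_size_of_model(wide_quantized_lstm, "int8 (dim=80)")
print("{0:.2f} times smaller".format(f_wide/q_wide))
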
######################################################################
# 4. Look at Latency
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The second benefit is that the quantized model will typically run
# faster. This is due to a combination of effects including at least:
#
# 1. Less time spent moving parameter data in
# 2. Faster INT8 operations
#
# As you will see, the quantized version of this super-simple network runs
# faster. This will generally be true of more complex networks, but as they
# say "your mileage may vary" depending on a number of factors including
# the structure of the model and the hardware you are running on.
#

# compare the performance
print("Floating point FP32")

#####################################################################
# .. code-block:: python
#
#    %timeit float_lstm.forward(inputs, hidden)

print("Quantized INT8")

######################################################################
# .. code-block:: python
#
#    %timeit quantized_lstm.forward(inputs, hidden)

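######################################################################
# If you are running this recipe as a plain Python script rather than in a
# notebook, the ``%timeit`` magic is not available; the sketch below uses the
# already-imported ``time`` module as a rough stand-in (wall-clock timing of a
# tiny random model is only indicative, not a rigorous benchmark).

def time_model(model, label, iterations=100):
    start = time.time()
    for _ in range(iterations):
        model(inputs, hidden)
    elapsed_ms = (time.time() - start) / iterations * 1000
    print('{0}: {1:.3f} ms per forward pass'.format(label, elapsed_ms))

time_model(float_lstm, "Floating point FP32")
time_model(quantized_lstm, "Quantized INT8")
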
######################################################################
# 5: Look at Accuracy
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We are not going to do a careful look at accuracy here because we are
# working with a randomly initialized network rather than a properly
# trained one. However, it is worth quickly showing that the
# quantized network does produce output tensors that are "in the same
# ballpark" as the original one.
#
# For a more detailed analysis please see the more advanced tutorials
# referenced at the end of this recipe.
#

# run the float model
out1, hidden1 = float_lstm(inputs, hidden)
mag1 = torch.mean(abs(out1)).item()
print('mean absolute value of output tensor values in the FP32 model is {0:.5f}'.format(mag1))

# run the quantized model
out2, hidden2 = quantized_lstm(inputs, hidden)
mag2 = torch.mean(abs(out2)).item()
print('mean absolute value of output tensor values in the INT8 model is {0:.5f}'.format(mag2))

# compare them
mag3 = torch.mean(abs(out1 - out2)).item()
print('mean absolute value of the difference between the output tensors is {0:.5f} or {1:.2f} percent'.format(mag3, mag3/mag1*100))

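######################################################################
# As a slightly different view of the same comparison, you can also look at
# the worst-case element-wise difference (a quick, optional check; a real
# accuracy evaluation needs a trained model and real data).

diff_max = torch.max(torch.abs(out1 - out2)).item()
print('max absolute difference between the output tensors is {0:.5f}'.format(diff_max))
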
######################################################################
# Learn More
# ------------
# We've explained what dynamic quantization is, what benefits it brings,
# and you have used the ``torch.quantization.quantize_dynamic()`` function
# to quickly quantize a simple LSTM model.
#
# This was a fast and high level treatment of this material; for more
# detail please continue learning with `(beta) Dynamic Quantization on an LSTM Word Language Model Tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_.
#
#
# Additional Resources
# --------------------
#
# * `Quantization API Documentation <https://pytorch.org/docs/stable/quantization.html>`_
# * `(beta) Dynamic Quantization on BERT <https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html>`_
# * `(beta) Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_
# * `Introduction to Quantization on PyTorch <https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_
#