# -*- coding: utf-8 -*-
"""
PyTorch Numeric Suite Tutorial
==============================

Introduction
------------

Quantization is good when it works, but it is difficult to know what is wrong when it does not satisfy the accuracy we expect. Debugging quantization accuracy issues is hard and time consuming.

One important step of debugging is to measure the statistics of the float model and its corresponding quantized model to know where they differ most. We built a suite of numeric tools called PyTorch Numeric Suite in PyTorch quantization to enable the measurement of statistics between a quantized module and its float counterpart, to support quantization debugging efforts. Even for a quantized model with good accuracy, PyTorch Numeric Suite can still be used as a profiling tool to better understand the quantization error within the model and to provide guidance for further optimization.

PyTorch Numeric Suite currently supports models quantized through both static quantization and dynamic quantization, with unified APIs.

In this tutorial we will first use ResNet18 as an example to show how to use PyTorch Numeric Suite to measure the statistics between a static quantized model and its float counterpart in eager mode. Then we will use an LSTM-based sequence model as an example to show the usage of PyTorch Numeric Suite for a dynamic quantized model.

Numeric Suite for Static Quantization
-------------------------------------

Setup
^^^^^
We'll start by doing the necessary imports:
"""

##############################################################################

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models, datasets
import torchvision.transforms as transforms
import os
import torch.quantization
import torch.quantization._numeric_suite as ns
from torch.quantization import (
    default_eval_fn,
    default_qconfig,
    quantize,
)

##############################################################################
# Then we load the pretrained float ResNet18 model and quantize it into ``qmodel``. We cannot compare two arbitrary models: only a float model and the quantized model derived from it can be compared.


float_model = torchvision.models.quantization.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1, quantize=False)
float_model.to('cpu')
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.quantization.default_qconfig
img_data = [(torch.rand(2, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
qmodel = quantize(float_model, default_eval_fn, [img_data], inplace=False)
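
##############################################################################
# As a quick optional check (our addition, not part of the original tutorial),
# printing the first convolution of each model shows the module swap performed
# by quantization; the exact printed class names depend on the PyTorch version:

print(float_model.conv1)  # fused float conv module
print(qmodel.conv1)       # its quantized counterpart, carrying scale and zero_point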

##############################################################################
# 1. Compare the weights of float and quantized models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The first thing we usually want to compare are the weights of the quantized model and the float model.
# We can call ``compare_weights()`` from PyTorch Numeric Suite to get a dictionary ``wt_compare_dict`` whose keys correspond to module names and whose entries are dictionaries with two keys, 'float' and 'quantized', containing the float and quantized weights.
# ``compare_weights()`` takes in the floating point and quantized state dicts and returns a dict whose keys correspond to the
# floating point weights and whose values are dictionaries of floating point and quantized weights.

wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

print("\nkeys of wt_compare_dict entry for conv1's weight:")
print(wt_compare_dict['conv1.weight'].keys())
print(wt_compare_dict['conv1.weight']['float'].shape)
print(wt_compare_dict['conv1.weight']['quantized'].shape)


##############################################################################
# Once we have ``wt_compare_dict``, we can process this dictionary in whatever way we want. Here, as an example, we compute the quantization error of the weights of the float and quantized models as follows:
# we compute the Signal-to-Quantization-Noise Ratio (SQNR) of the quantized tensor ``y``. The SQNR reflects the
# relationship between the maximum nominal signal strength and the quantization error introduced by
# quantization. A higher SQNR corresponds to a lower quantization error.

def compute_error(x, y):
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
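
##############################################################################
# As an illustrative sketch (our addition, not part of the original tutorial),
# we can sanity-check ``compute_error`` by quantizing and dequantizing a toy
# tensor directly; the scale and zero point below are arbitrary assumptions,
# and the exact SQNR will vary from run to run:

x = torch.randn(100)
xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)
print(compute_error(x, xq.dequantize()))  # typically a few tens of dB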

for key in wt_compare_dict:
    print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

##############################################################################
# As another example, ``wt_compare_dict`` can also be used to plot histograms of the weights of the floating point and quantized models.

import matplotlib.pyplot as plt

f = wt_compare_dict['conv1.weight']['float'].flatten()
plt.hist(f, bins=100)
plt.title("Floating point model weights of conv1")
plt.show()

q = wt_compare_dict['conv1.weight']['quantized'].flatten().dequantize()
plt.hist(q, bins=100)
plt.title("Quantized model weights of conv1")
plt.show()


##############################################################################
#
# 2. Compare floating point and quantized models at corresponding locations
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The second tool allows for comparison of the weights and activations of the float and quantized models at corresponding locations for the same input, as shown in the figure below. Red arrows indicate the locations of the comparison.
#
# .. figure:: /_static/img/compare_output.png
#
# We call ``compare_model_outputs()`` from PyTorch Numeric Suite to get the activations of the float model and the quantized model at corresponding locations for the given input data. This API returns a dict with module names as keys. Each entry is itself a dict with two keys, 'float' and 'quantized', containing the activations.
data = img_data[0][0]

# Take in the floating point and quantized models as well as input data, and return a dict whose keys
# correspond to the quantized module names and whose entries are dictionaries with two keys, 'float' and
# 'quantized', containing the activations of the floating point and quantized models at matching locations.
act_compare_dict = ns.compare_model_outputs(float_model, qmodel, data)

print('keys of act_compare_dict:')
print(act_compare_dict.keys())

print("\nkeys of act_compare_dict entry for conv1's output:")
print(act_compare_dict['conv1.stats'].keys())
print(act_compare_dict['conv1.stats']['float'][0].shape)
print(act_compare_dict['conv1.stats']['quantized'][0].shape)

##############################################################################
# This dict can be used to compare and compute the quantization error of the activations of the float and quantized models as follows:
for key in act_compare_dict:
    print(key, compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize()))

##############################################################################
# If we want to do the comparison for more than one input, we can do the following:
# prepare the models by attaching the logger to both the floating point and quantized
# modules if they are in the ``white_list``. The default logger is ``OutputLogger``, and the default ``white_list``
# is ``DEFAULT_NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_WHITE_LIST``.
ns.prepare_model_outputs(float_model, qmodel)

for data in img_data:
    float_model(data[0])
    qmodel(data[0])

# Find the matching activations between the floating point and quantized modules, and return a dict whose keys
# correspond to quantized module names and whose entries are dictionaries with two keys, 'float'
# and 'quantized', containing the matching floating point and quantized activations logged by the logger.
act_compare_dict = ns.get_matching_activations(float_model, qmodel)
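
##############################################################################
# With several inputs logged, each entry of ``act_compare_dict`` now holds a
# list of activations, one per forward pass. As an illustrative sketch (our
# addition, not part of the original tutorial), we can average the SQNR over
# the logged inputs for each module:

for key in act_compare_dict:
    errors = [compute_error(f, q.dequantize())
              for f, q in zip(act_compare_dict[key]['float'],
                              act_compare_dict[key]['quantized'])]
    print(key, sum(errors) / len(errors))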


##############################################################################
# The default logger used in the above APIs is ``OutputLogger``, which logs the outputs of the modules. We can inherit from the base ``Logger`` class and create our own logger to perform different functionality. For example, we can make a new ``MyOutputLogger`` class as below.

class MyOutputLogger(ns.Logger):
    r"""Customized logger class
    """

    def __init__(self):
        super(MyOutputLogger, self).__init__()

    def forward(self, x):
        # Custom functionalities
        # ...
        return x
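
##############################################################################
# As a concrete (hypothetical) illustration of what such a logger might do, the
# sketch below records simple summary statistics of every output it sees. The
# ``mean`` and ``max_abs`` keys are our own invention; we assume the base
# ``Logger`` class exposes a ``stats`` dictionary, as the default loggers do:

class StatsOutputLogger(ns.Logger):
    r"""Logs the mean and maximum absolute value of module outputs."""

    def __init__(self):
        super(StatsOutputLogger, self).__init__()
        self.stats["mean"] = []
        self.stats["max_abs"] = []

    def forward(self, x):
        # Dequantize if needed so statistics are computed in floating point.
        x_float = x.dequantize() if x.is_quantized else x
        self.stats["mean"].append(x_float.mean().item())
        self.stats["max_abs"].append(x_float.abs().max().item())
        return x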

##############################################################################
# And then we can pass this logger into the above APIs, such as:

data = img_data[0][0]
act_compare_dict = ns.compare_model_outputs(float_model, qmodel, data, logger_cls=MyOutputLogger)

##############################################################################
# or:

ns.prepare_model_outputs(float_model, qmodel, MyOutputLogger)
for data in img_data:
    float_model(data[0])
    qmodel(data[0])
act_compare_dict = ns.get_matching_activations(float_model, qmodel)


##############################################################################
#
# 3. Compare a module in a quantized model with its floating point equivalent, with the same input data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The third tool allows for comparing a quantized module in a model with its floating point counterpart, feeding both of them the same input and comparing their outputs, as shown below.
#
# .. figure:: /_static/img/compare_stub.png
#
# In practice we call ``prepare_model_with_stubs()`` to swap the quantized module that we want to compare with the Shadow module, as illustrated below:
#
# .. figure:: /_static/img/shadow.png
#
# The Shadow module takes a quantized module, a float module, and a logger as input, and creates a forward path inside it that makes the float module shadow the quantized module, with both sharing the same input tensor.
#
# The logger is customizable; the default logger is ``ShadowLogger``, which saves the outputs of the quantized module and the float module so that they can be used to compute the module-level quantization error.
#
# Notice that before each call of ``compare_model_outputs()`` and ``compare_model_stub()`` we need to have clean float and quantized models. This is because ``compare_model_outputs()`` and ``compare_model_stub()`` modify the float and quantized models in place, which causes unexpected results if one is called right after the other.

float_model = torchvision.models.quantization.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1, quantize=False)
float_model.to('cpu')
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.quantization.default_qconfig
img_data = [(torch.rand(2, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
qmodel = quantize(float_model, default_eval_fn, [img_data], inplace=False)

##############################################################################
# In the following example we call ``compare_model_stub()`` from PyTorch Numeric Suite to compare the ``QuantizableBasicBlock`` module with its floating point equivalent. This API returns a dict whose keys correspond to module names and whose entries are dictionaries with two keys, 'float' and 'quantized', containing the output tensors of the quantized module and its matching float shadow module.

data = img_data[0][0]
module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]

# Take in the floating point and quantized models as well as input data, and return a dict whose keys
# correspond to module names and whose entries are dictionaries with two keys, 'float' and
# 'quantized', containing the output tensors of the quantized module and its matching floating point shadow module.
ob_dict = ns.compare_model_stub(float_model, qmodel, module_swap_list, data)

print('keys of ob_dict:')
print(ob_dict.keys())

print("\nkeys of ob_dict entry for layer1.0's output:")
print(ob_dict['layer1.0.stats'].keys())
print(ob_dict['layer1.0.stats']['float'][0].shape)
print(ob_dict['layer1.0.stats']['quantized'][0].shape)

##############################################################################
# This dict can then be used to compare and compute the module-level quantization error.

for key in ob_dict:
    print(key, compute_error(ob_dict[key]['float'][0], ob_dict[key]['quantized'][0].dequantize()))

##############################################################################
# If we want to do the comparison for more than one input, we can do the following:

ns.prepare_model_with_stubs(float_model, qmodel, module_swap_list, ns.ShadowLogger)
for data in img_data:
    qmodel(data[0])
ob_dict = ns.get_logger_dict(qmodel)
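
##############################################################################
# As with the output comparison above, each entry of ``ob_dict`` now holds one
# logged pair per forward pass. A small illustrative sketch (our addition, not
# part of the original tutorial) that averages the module-level SQNR over the
# inputs:

for key in ob_dict:
    errors = [compute_error(f, q.dequantize())
              for f, q in zip(ob_dict[key]['float'], ob_dict[key]['quantized'])]
    print(key, sum(errors) / len(errors))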

##############################################################################
# The default logger used in the above APIs is ``ShadowLogger``, which logs the outputs of the quantized module and its matching float shadow module. We can inherit from the base ``Logger`` class and create our own logger to perform different functionality. For example, we can make a new ``MyShadowLogger`` class as below.

class MyShadowLogger(ns.Logger):
    r"""Customized logger class
    """

    def __init__(self):
        super(MyShadowLogger, self).__init__()

    def forward(self, x, y):
        # Custom functionalities
        # ...
        return x
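
##############################################################################
# As a concrete (hypothetical) counterpart to the ``StatsOutputLogger`` sketch
# above, a shadow logger could compute the SQNR between the two outputs as they
# are observed. The ``sqnr`` key is our own invention, and we again assume the
# base ``Logger`` exposes a ``stats`` dictionary:

class SQNRShadowLogger(ns.Logger):
    r"""Logs the SQNR between quantized and float shadow module outputs."""

    def __init__(self):
        super(SQNRShadowLogger, self).__init__()
        self.stats["sqnr"] = []

    def forward(self, x, y):
        # x is the quantized module's output, y is the float shadow's output.
        x_float = x.dequantize() if x.is_quantized else x
        self.stats["sqnr"].append(compute_error(y, x_float).item())
        return x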

##############################################################################
# And then we can pass this logger into the above APIs, such as:

data = img_data[0][0]
ob_dict = ns.compare_model_stub(float_model, qmodel, module_swap_list, data, logger_cls=MyShadowLogger)

##############################################################################
# or:

ns.prepare_model_with_stubs(float_model, qmodel, module_swap_list, MyShadowLogger)
for data in img_data:
    qmodel(data[0])
ob_dict = ns.get_logger_dict(qmodel)

###############################################################################
# Numeric Suite for Dynamic Quantization
# --------------------------------------
#
# Numeric Suite APIs are designed in such a way that they work for both dynamic quantized models and static quantized models. We will use a model with both LSTM and Linear modules to demonstrate the usage of Numeric Suite on a dynamic quantized model. This model is the same one used in the tutorial on dynamic quantization of an LSTM word language model [1].
#

#################################
# Setup
# ^^^^^
# First we define the model as below. Notice that within this model only the ``nn.LSTM`` and ``nn.Linear`` modules will be quantized dynamically; ``nn.Embedding`` will remain a floating point module after quantization.

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.encoder(input)
        output, hidden = self.rnn(emb, hidden)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

##############################################################################
# Then we create the ``float_model`` and quantize it into ``qmodel``.

ntokens = 10

float_model = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)

float_model.eval()

qmodel = torch.quantization.quantize_dynamic(
    float_model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
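
##############################################################################
# As an optional check (our addition, not part of the original tutorial),
# printing ``qmodel`` shows which submodules were swapped for dynamically
# quantized counterparts; the exact class names depend on the PyTorch version:

print(qmodel)  # expect dynamically quantized LSTM and Linear, float Embedding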

##############################################################################
#
# 1. Compare the weights of float and quantized models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We first call ``compare_weights()`` from PyTorch Numeric Suite to get a dictionary ``wt_compare_dict`` whose keys correspond to module names and whose entries are dictionaries with two keys, 'float' and 'quantized', containing the float and quantized weights.

wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())

##############################################################################
# Once we get ``wt_compare_dict``, it can be used to compare and compute the quantization error of the weights of the float and quantized models as follows:

for key in wt_compare_dict:
    if wt_compare_dict[key]['quantized'].is_quantized:
        print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
    else:
        print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized']))

##############################################################################
#
# The ``inf`` value in the ``encoder.weight`` entry above appears because the encoder module is not quantized, so its weights are identical in the floating point and quantized models and the quantization noise is zero.
#
# 2. Compare floating point and quantized models at corresponding locations
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Then we call ``compare_model_outputs()`` from PyTorch Numeric Suite to get the activations of the float model and the quantized model at corresponding locations for the given input data. This API returns a dict with module names as keys. Each entry is itself a dict with two keys, 'float' and 'quantized', containing the activations. Notice that this sequence model has two inputs, and we can pass both inputs into ``compare_model_outputs()`` and ``compare_model_stub()``.


input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = float_model.init_hidden(1)

act_compare_dict = ns.compare_model_outputs(float_model, qmodel, input_, hidden)
print(act_compare_dict.keys())

##############################################################################
# This dict can be used to compare and compute the quantization error of the activations of the float and quantized models as follows. The LSTM module in this model has two outputs; in this example we compute the error of the first output.


for key in act_compare_dict:
    print(key, compute_error(act_compare_dict[key]['float'][0][0], act_compare_dict[key]['quantized'][0][0]))
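
##############################################################################
# The second LSTM output is the (hidden, cell) state tuple. As an illustrative
# sketch (our addition, not part of the original tutorial), the same comparison
# can be applied to the hidden state; 'rnn.stats' is the key we assume the LSTM
# module gets in ``act_compare_dict``, as printed above:

key = 'rnn.stats'
print('LSTM hidden state error:',
      compute_error(act_compare_dict[key]['float'][0][1][0],
                    act_compare_dict[key]['quantized'][0][1][0]))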

##############################################################################
#
# 3. Compare a module in a quantized model with its floating point equivalent, with the same input data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Next we call ``compare_model_stub()`` from PyTorch Numeric Suite to compare the LSTM and Linear modules with their floating point equivalents. This API returns a dict whose keys correspond to module names and whose entries are dictionaries with two keys, 'float' and 'quantized', containing the output tensors of the quantized module and its matching float shadow module.
#
# We reset the model first.


float_model = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)
float_model.eval()

qmodel = torch.quantization.quantize_dynamic(
    float_model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

##############################################################################
# Then we call ``compare_model_stub()`` with a module swap list that contains the ``nn.LSTM`` and ``nn.Linear`` modules:

module_swap_list = [nn.Linear, nn.LSTM]
ob_dict = ns.compare_model_stub(float_model, qmodel, module_swap_list, input_, hidden)
print(ob_dict.keys())

##############################################################################
# This dict can then be used to compare and compute the module-level quantization error.

for key in ob_dict:
    print(key, compute_error(ob_dict[key]['float'][0], ob_dict[key]['quantized'][0]))

##############################################################################
# An SQNR of 40 dB is high; this is a situation where we have very good numerical alignment between the floating point and quantized models.
#
# Conclusion
# ----------
# In this tutorial, we demonstrated how to use PyTorch Numeric Suite to measure and compare the statistics between a quantized model and its float model in eager mode, with unified APIs for both static quantization and dynamic quantization.
#
# Thanks for reading! As always, we welcome any feedback, so please create an issue `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.
#
# References
# ----------
# [1] `Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_.