GitHub Repository: pytorch/tutorials
Path: blob/main/prototype_source/fx_graph_mode_ptq_dynamic.py

"""
(prototype) FX Graph Mode Post Training Dynamic Quantization
============================================================

**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

This tutorial introduces the steps to do post training dynamic quantization in graph mode based on ``torch.fx``.
We have a separate tutorial for `FX Graph Mode Post Training Static Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`_,
and a comparison between FX Graph Mode Quantization and Eager Mode Quantization can be found in the `quantization docs <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`_.

tl;dr: The FX Graph Mode API for dynamic quantization looks like the following:

.. code:: python

    import torch
    from torch.ao.quantization import default_dynamic_qconfig, QConfigMapping
    # Note that this is temporary; we'll expose these functions to torch.ao.quantization after the official release.
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    float_model.eval()
    # use the default dynamic qconfig for every quantizable module/op in the model
    qconfig_mapping = QConfigMapping().set_global(default_dynamic_qconfig)
    prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)  # fuse modules and insert observers
    # no calibration is required for dynamic quantization
    quantized_model = convert_fx(prepared_model)  # convert the model to a dynamically quantized model

In this tutorial, we’ll apply dynamic quantization to an LSTM-based next-word prediction model,
closely following the word language model from the PyTorch examples.
We will copy the code from `Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_
and omit the descriptions.

"""


###################################################
# 1. Define the Model, Download Data and Model
# --------------------------------------------
#
# Download the `data <https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip>`_
# and unzip it into the ``data`` folder:
#
# .. code::
#
#    mkdir data
#    cd data
#    wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
#    unzip wikitext-2-v1.zip
#
# Download the model into the ``data`` folder:
#
# .. code::
#
#    wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
#
# Define the model:

# imports
import os
from io import open
import time
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model Definition
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

def init_hidden(lstm_model, bsz):
    # get the weight tensor so we can create the hidden state on the same device
    weight = lstm_model.encoder.weight
    # in the quantized model the embedding weight is exposed as a callable, not a plain tensor
    if not isinstance(weight, torch.Tensor):
        weight = weight()
    device = weight.device
    nlayers = lstm_model.rnn.num_layers
    nhid = lstm_model.rnn.hidden_size
    return (torch.zeros(nlayers, bsz, nhid, device=device),
            torch.zeros(nlayers, bsz, nhid, device=device))

# Load Text Data
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'wiki.train.tokens'))
        self.valid = self.tokenize(os.path.join(path, 'wiki.valid.tokens'))
        self.test = self.tokenize(os.path.join(path, 'wiki.test.tokens'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

ntokens = len(corpus.dictionary)

# Load Pretrained Model
model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model.eval()
print(model)

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)
example_inputs = (next(iter(test_data))[0])

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = init_hidden(model_, eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

######################################################################
# 2. Post Training Dynamic Quantization
# -------------------------------------
# Now we can dynamically quantize the model.
# We use the same functions as in post training static quantization, but with dynamic qconfigs.

from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig, QConfigMapping

# Full docs on the supported qconfigs for floating point modules/ops can be found in the `quantization docs <https://pytorch.org/docs/stable/quantization.html#module-torch.quantization>`_,
# and full docs for `QConfigMapping <https://pytorch.org/docs/stable/generated/torch.ao.quantization.qconfig_mapping.QConfigMapping.html#torch.ao.quantization.qconfig_mapping.QConfigMapping>`_ are available as well.
qconfig_mapping = (QConfigMapping()
    .set_object_type(nn.Embedding, float_qparams_weight_only_qconfig)
    .set_object_type(nn.LSTM, default_dynamic_qconfig)
    .set_object_type(nn.Linear, default_dynamic_qconfig)
)

# Reload the float model to create a fresh copy, because the quantization API changes the model
# in place and we want to keep the original model for later comparison.

model_to_quantize = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model_to_quantize.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model_to_quantize.eval()

prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
print("prepared model:", prepared_model)
quantized_model = convert_fx(prepared_model)
print("quantized model:", quantized_model)

######################################################################
# For dynamically quantized modules, ``prepare_fx`` does not insert observers on
# the modules themselves; it only inserts weight observers for dynamically
# quantizable functionals and torch ops. It also fuses modules such as
# Conv + BatchNorm and Linear + ReLU.
#
# In ``convert_fx`` we convert the float modules to dynamically quantized modules and
# the float ops to dynamically quantized ops. We can see that in the example model,
# ``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM`` are dynamically quantized.
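#
# As a quick sanity check, one way to see which modules were replaced by their
# dynamically quantized counterparts is to list the submodule types of the
# converted ``GraphModule``:

# print the type of every submodule in the converted model
for name, module in quantized_model.named_modules():
    print(name, type(module))

######################################################################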
#
# Now we can compare the size and runtime of the quantized model.

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

######################################################################
# There is a roughly 4x size reduction because we quantized all of the weights
# in the model (nn.Embedding, nn.Linear and nn.LSTM) from 32-bit float (4 bytes)
# to quantized 8-bit int (1 byte).

# run the comparison single-threaded so the timings are stable and comparable
torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

######################################################################
# There is a roughly 2x speedup for this model. Note that the speedup
# may vary depending on the model, device, build, input batch size, threading, and so on.
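#
# For example, one way to get a feel for the threading dependence is to repeat
# the measurement with a different number of intra-op threads (the thread count
# of 4 below is an arbitrary choice; adjust it for your machine):

torch.set_num_threads(4)  # arbitrary thread count chosen for illustration
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
torch.set_num_threads(1)  # restore the single-threaded setting used above

######################################################################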
#
# 3. Conclusion
# -------------
# This tutorial introduced the API for post training dynamic quantization in FX Graph Mode,
# which dynamically quantizes the same modules as Eager Mode Quantization.