"""
(beta) Dynamic Quantization on an LSTM Word Language Model
==================================================================

**Author**: `James Reed <https://github.com/jamesr66a>`_

**Edited by**: `Seth Weidman <https://github.com/SethHWeidman/>`_

Introduction
------------

Quantization involves converting the weights and activations of your model from
floating point to integer, which can result in a smaller model size and faster
inference with only a small hit to accuracy.

In this tutorial, we will apply the simplest form of quantization -
`dynamic quantization <https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic>`_ -
to an LSTM-based next-word prediction model, closely following the
`word language model <https://github.com/pytorch/examples/tree/master/word_language_model>`_
from the PyTorch examples.
"""
# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
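
######################################################################
# An illustrative aside (not part of the original tutorial): quantization maps
# float values onto a small set of integers via a scale and a zero point. We
# can see this directly on a toy tensor with ``torch.quantize_per_tensor``;
# the scale and values below are arbitrary choices for illustration.

_x = torch.tensor([0.1, -0.2, 0.3])
_xq = torch.quantize_per_tensor(_x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(_xq)             # quantized tensor (shown with its scale and zero point)
print(_xq.int_repr())  # the underlying int8 values, e.g. tensor([ 1, -2,  3])
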
######################################################################
# 1. Define the model
# -------------------
#
# Here we define the LSTM model architecture, following the
# `model <https://github.com/pytorch/examples/blob/master/word_language_model/model.py>`_
# from the word language model example.

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
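
######################################################################
# As a quick sanity check of the architecture above (an illustrative aside,
# not part of the original example), we can run a forward pass on random
# token ids; the tiny sizes below are arbitrary.

_check_model = LSTMModel(ntoken=50, ninp=16, nhid=8, nlayers=2)
_check_hidden = _check_model.init_hidden(bsz=3)
_check_input = torch.randint(50, (7, 3), dtype=torch.long)  # (seq_len, batch)
_check_out, _check_hidden = _check_model(_check_input, _check_hidden)
print(_check_out.shape)  # expected: torch.Size([7, 3, 50]), one logit per vocabulary word
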
######################################################################
# 2. Load in the text data
# ------------------------
#
# Next, we load the
# `Wikitext-2 dataset <https://www.google.com/search?q=wikitext+2+data>`_ into a `Corpus`,
# again following the
# `preprocessing <https://github.com/pytorch/examples/blob/master/word_language_model/data.py>`_
# from the word language model example.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')
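
######################################################################
# An illustrative aside (not in the original example): the ``Dictionary``
# assigns ids in the order words are first seen, which we can check on a
# toy sentence without touching the Wikitext-2 data.

_toy_dict = Dictionary()
_toy_ids = [_toy_dict.add_word(w) for w in "the cat sat on the mat".split()]
print(_toy_ids)                # [0, 1, 2, 3, 0, 4] - 'the' reuses id 0
print(_toy_dict.idx2word[:3])  # ['the', 'cat', 'sat']
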
######################################################################
# 3. Load the pretrained model
# -----------------------------
#
# This is a tutorial on dynamic quantization, a quantization technique
# that is applied after a model has been trained. Therefore, we'll simply load some
# pretrained weights into this model architecture; these weights were obtained
# by training for five epochs using the default settings in the word language model
# example.

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model.eval()
print(model)
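
######################################################################
# An illustrative aside (not in the original example): listing the parameter
# shapes shows where the model's size comes from, which is useful context for
# the quantization savings we measure later.

for name, p in model.named_parameters():
    print('{:30s} {}'.format(name, tuple(p.shape)))
print('total parameters:', sum(p.numel() for p in model.parameters()))
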
######################################################################
# Now let's generate some text to ensure that the pretrained model is working
# properly - similarly to before, we follow the
# `generation script <https://github.com/pytorch/examples/blob/master/word_language_model/generate.py>`_
# from the word language model example.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, num_words))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
######################################################################
# It's no GPT-2, but it looks like the model has started to learn the structure of
# language!
#
# We're almost ready to demonstrate dynamic quantization. We just need to define a few more
# helper functions:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into ``bsz`` parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the ``bsz`` batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
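
######################################################################
# As an illustrative aside (not part of the original example), ``batchify``
# and ``get_batch`` can be exercised on a small integer tensor to make the
# resulting shapes concrete; the sizes below are arbitrary.

_toy_data = torch.arange(103)              # pretend corpus of 103 token ids
_toy_batched = batchify(_toy_data, bsz=4)  # trims to 100 ids, then reshapes
print(_toy_batched.shape)                  # torch.Size([25, 4])
_toy_x, _toy_y = get_batch(_toy_batched, 0)
print(_toy_x.shape, _toy_y.shape)          # inputs of up to ``bptt`` steps, flattened targets
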
######################################################################
# 4. Test dynamic quantization
# ----------------------------
#
# Finally, we can call ``torch.quantization.quantize_dynamic`` on the model!
# Specifically,
#
# - We specify that we want the ``nn.LSTM`` and ``nn.Linear`` modules in our
#   model to be quantized
# - We specify that we want weights to be converted to ``int8`` values

import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
######################################################################
# The model looks the same; how has this benefited us? First, we see a
# significant reduction in model size:

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)
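
######################################################################
# An illustrative aside (not in the original example): a back-of-the-envelope
# estimate of the saving - int8 weights take 1 byte instead of 4 for float32,
# so the LSTM and Linear weights shrink by roughly 4x, while the (unquantized)
# embedding keeps its original size.

_lstm_linear_params = sum(p.numel() for m in model.modules()
                          if isinstance(m, (nn.LSTM, nn.Linear))
                          for p in m.parameters(recurse=False))
_embedding_params = sum(p.numel() for p in model.encoder.parameters())
print('approx. float32 MB in LSTM/Linear:', 4 * _lstm_linear_params / 1e6)
print('approx. float32 MB in embedding  :', 4 * _embedding_params / 1e6)
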
######################################################################
# Second, we see faster inference time, with no difference in evaluation loss:
#
# Note: we set the number of threads to one for a single-threaded comparison, since quantized
# models run single-threaded.

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
######################################################################
# Running this locally on a MacBook Pro, without quantization, inference takes about 200 seconds,
# and with quantization it takes just about 100 seconds.
#
# Conclusion
# ----------
#
# Dynamic quantization can be an easy way to reduce model size while only
# having a limited effect on accuracy.
#
# Thanks for reading! As always, we welcome any feedback, so please create an issue
# `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.