"""1(beta) Dynamic Quantization on an LSTM Word Language Model2==================================================================34**Author**: `James Reed <https://github.com/jamesr66a>`_56**Edited by**: `Seth Weidman <https://github.com/SethHWeidman/>`_78Introduction9------------1011Quantization involves converting the weights and activations of your model from float12to int, which can result in smaller model size and faster inference with only a small13hit to accuracy.1415In this tutorial, we will apply the easiest form of quantization -16`dynamic quantization <https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic>`_ -17to an LSTM-based next word-prediction model, closely following the18`word language model <https://github.com/pytorch/examples/tree/master/word_language_model>`_19from the PyTorch examples.20"""2122# imports23import os24from io import open25import time2627import torch28import torch.nn as nn29import torch.nn.functional as F3031######################################################################32# 1. Define the model33# -------------------34#35# Here we define the LSTM model architecture, following the36# `model <https://github.com/pytorch/examples/blob/master/word_language_model/model.py>`_37# from the word language model example.3839class LSTMModel(nn.Module):40"""Container module with an encoder, a recurrent module, and a decoder."""4142def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):43super(LSTMModel, self).__init__()44self.drop = nn.Dropout(dropout)45self.encoder = nn.Embedding(ntoken, ninp)46self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)47self.decoder = nn.Linear(nhid, ntoken)4849self.init_weights()5051self.nhid = nhid52self.nlayers = nlayers5354def init_weights(self):55initrange = 0.156self.encoder.weight.data.uniform_(-initrange, initrange)57self.decoder.bias.data.zero_()58self.decoder.weight.data.uniform_(-initrange, initrange)5960def forward(self, input, hidden):61emb = self.drop(self.encoder(input))62output, hidden = self.rnn(emb, hidden)63output = self.drop(output)64decoded = self.decoder(output)65return decoded, hidden6667def init_hidden(self, bsz):68weight = next(self.parameters())69return (weight.new_zeros(self.nlayers, bsz, self.nhid),70weight.new_zeros(self.nlayers, bsz, self.nhid))7172######################################################################73# 2. 

######################################################################
# 2. Load in the text data
# ------------------------
#
# Next, we load the
# `Wikitext-2 dataset <https://www.google.com/search?q=wikitext+2+data>`_ into a `Corpus`,
# again following the
# `preprocessing <https://github.com/pytorch/examples/blob/master/word_language_model/data.py>`_
# from the word language model example.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')
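
######################################################################
# A quick, purely illustrative look at what we just built (this assumes the
# Wikitext-2 files were found under ``data/wikitext-2/`` as above): the
# ``Dictionary`` maps each unique token to an integer id, and each split is a
# single 1-D tensor of those ids.

print('vocabulary size:', len(corpus.dictionary))
print('test set size:', corpus.test.size(0), 'tokens')
print('first 10 test tokens:',
      [corpus.dictionary.idx2word[idx] for idx in corpus.test[:10].tolist()])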

######################################################################
# 3. Load the pretrained model
# -----------------------------
#
# This is a tutorial on dynamic quantization, a quantization technique
# that is applied after a model has been trained. Therefore, we'll simply load some
# pretrained weights into this model architecture; these weights were obtained
# by training for five epochs using the default settings in the word language model
# example.

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model.eval()
print(model)

######################################################################
# Now let's generate some text to ensure that the pretrained model is working
# properly - similar to before, we follow the
# `generate script <https://github.com/pytorch/examples/blob/master/word_language_model/generate.py>`_
# from the word language model example.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, num_words))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)

######################################################################
# It's no GPT-2, but it looks like the model has started to learn the structure of
# language!
#
# We're almost ready to demonstrate dynamic quantization. We just need to define a few more
# helper functions:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into ``bsz`` parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the ``bsz`` batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
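
######################################################################
# Since ``evaluate`` returns an average per-token cross-entropy loss, it is
# also easy to report word-level perplexity, ``exp(loss)``, which is the
# metric usually quoted for word language models. The helper below is an
# optional extra, not part of the original example; the commented-out call
# shows how it would be used (a full pass over the test set is slow, so we
# don't run it here).

import math

def perplexity(loss):
    """Convert an average cross-entropy loss (in nats) to perplexity."""
    return math.exp(loss)

# Example usage:
# print('perplexity: {:.2f}'.format(perplexity(evaluate(model, test_data))))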

######################################################################
# 4. Test dynamic quantization
# ----------------------------
#
# Finally, we can call ``torch.quantization.quantize_dynamic`` on the model!
# Specifically,
#
# - We specify that we want the ``nn.LSTM`` and ``nn.Linear`` modules in our
#   model to be quantized
# - We specify that we want weights to be converted to ``int8`` values

import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

######################################################################
# The model looks the same; how has this benefited us? First, we see a
# significant reduction in model size:

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

######################################################################
# Second, we see faster inference time, with no difference in evaluation loss:
#
# Note: we set the number of threads to one for a single-threaded comparison,
# since quantized models run single-threaded.

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

######################################################################
# Running this locally on a MacBook Pro, without quantization, inference takes about 200 seconds,
# and with quantization it takes just about 100 seconds.
#
# Conclusion
# ----------
#
# Dynamic quantization can be an easy way to reduce model size while only
# having a limited effect on accuracy.
#
# Thanks for reading! As always, we welcome any feedback, so please create an issue
# `here <https://github.com/pytorch/pytorch/issues>`_ if you have any.