Path: blob/main/prototype_source/fx_graph_mode_ptq_dynamic.py
"""1(prototype) FX Graph Mode Post Training Dynamic Quantization2============================================================34**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_56This tutorial introduces the steps to do post training dynamic quantization in graph mode based on ``torch.fx``.7We have a separate tutorial for `FX Graph Mode Post Training Static Quantization <https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html>`_,8comparison between FX Graph Mode Quantization and Eager Mode Quantization can be found in the `quantization docs <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`_910tldr; The FX Graph Mode API for dynamic quantization looks like the following:1112.. code:: python1314import torch15from torch.ao.quantization import default_dynamic_qconfig, QConfigMapping16# Note that this is temporary, we'll expose these functions to torch.ao.quantization after official releasee17from torch.quantization.quantize_fx import prepare_fx, convert_fx1819float_model.eval()20# The old 'fbgemm' is still available but 'x86' is the recommended default.21qconfig = get_default_qconfig("x86")22qconfig_mapping = QConfigMapping().set_global(qconfig)23prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) # fuse modules and insert observers24# no calibration is required for dynamic quantization25quantized_model = convert_fx(prepared_model) # convert the model to a dynamically quantized model2627In this tutorial, we’ll apply dynamic quantization to an LSTM-based next word-prediction model,28closely following the word language model from the PyTorch examples.29We will copy the code from `Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_30and omit the descriptions.3132"""333435###################################################36# 1. Define the Model, Download Data and Model37# --------------------------------------------38#39# Download the `data <https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip>`_40# and unzip to data folder41#42# .. code::43#44# mkdir data45# cd data46# wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip47# unzip wikitext-2-v1.zip48#49# Download model to the data folder:50#51# .. 
#
#    wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
#
# Define the model:

# imports
import os
from io import open
import time
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model Definition
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden


def init_hidden(lstm_model, bsz):
    # get the weight tensor and create the hidden state on the same device
    weight = lstm_model.encoder.weight
    # get the weight from a quantized model
    if not isinstance(weight, torch.Tensor):
        weight = weight()
    device = weight.device
    nlayers = lstm_model.rnn.num_layers
    nhid = lstm_model.rnn.hidden_size
    return (torch.zeros(nlayers, bsz, nhid, device=device),
            torch.zeros(nlayers, bsz, nhid, device=device))


# Load Text Data
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'wiki.train.tokens'))
        self.valid = self.tokenize(os.path.join(path, 'wiki.valid.tokens'))
        self.test = self.tokenize(os.path.join(path, 'wiki.test.tokens'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids


model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

ntokens = len(corpus.dictionary)

# Load Pretrained Model
model = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model.eval()
print(model)

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create the test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)
example_inputs = (next(iter(test_data))[0])

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode, which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = init_hidden(model_, eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


######################################################################
# 2. Post Training Dynamic Quantization
# -------------------------------------
#
# Now we can dynamically quantize the model.
# We can use the same functions as in post training static quantization, but with a dynamic qconfig.

from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import default_dynamic_qconfig, float_qparams_weight_only_qconfig, QConfigMapping

# Full docs for the supported qconfigs for floating point modules/ops can be found in the `quantization docs <https://pytorch.org/docs/stable/quantization.html#module-torch.quantization>`_
# Full docs for `QConfigMapping <https://pytorch.org/docs/stable/generated/torch.ao.quantization.qconfig_mapping.QConfigMapping.html#torch.ao.quantization.qconfig_mapping.QConfigMapping>`_
qconfig_mapping = (QConfigMapping()
    .set_object_type(nn.Embedding, float_qparams_weight_only_qconfig)
    .set_object_type(nn.LSTM, default_dynamic_qconfig)
    .set_object_type(nn.Linear, default_dynamic_qconfig)
)

# Reload the model to create the model to quantize, because the quantization API changes the model
# in place and we want to keep the original model for later comparison.
model_to_quantize = LSTMModel(
    ntoken=ntokens,
    ninp=512,
    nhid=256,
    nlayers=5,
)

model_to_quantize.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
    )
)

model_to_quantize.eval()


prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
print("prepared model:", prepared_model)
quantized_model = convert_fx(prepared_model)
print("quantized model:", quantized_model)


######################################################################
# For dynamically quantized objects, ``prepare_fx`` does not change the modules themselves,
# but it inserts weight observers for dynamically quantizable functionals and torch ops.
# We also fuse modules such as Conv + BN and Linear + ReLU.
#
# In ``convert_fx`` we convert the float modules to dynamically quantized modules and
# the float ops to dynamically quantized ops. We can see that in the example model,
# ``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM`` are dynamically quantized.
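
######################################################################
# As a quick, optional sanity check, we can run a single batch through both the float
# and the quantized model and confirm that the outputs stay close. This is only a rough
# spot check on one batch, not an accuracy evaluation; the loss comparison later in the
# tutorial is the real measurement.

sanity_data, _ = get_batch(test_data, 0)
with torch.no_grad():
    float_out, _ = model(sanity_data, init_hidden(model, eval_batch_size))
    quant_out, _ = quantized_model(sanity_data, init_hidden(quantized_model, eval_batch_size))
print('max abs difference on one batch:', (float_out - quant_out).abs().max().item())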

######################################################################
# Now we can compare the size and runtime of the quantized model.

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

######################################################################
# There is a 4x size reduction because we quantized all the weights
# in the model (``nn.Embedding``, ``nn.Linear`` and ``nn.LSTM``) from float (4 bytes) to quantized int8 (1 byte).
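
######################################################################
# As a rough back-of-the-envelope check of that 4x number, we can estimate both sizes from
# the parameter count alone: about 4 bytes per parameter in the float model and about 1 byte
# per parameter after quantization. This deliberately ignores the small per-tensor
# scale/zero-point overhead and the fact that biases stay in floating point, so it is only
# an approximation of the sizes printed above.

num_params = sum(p.numel() for p in model.parameters())
print('approx. float32 size (MB):', num_params * 4 / 1e6)
print('approx. int8 size (MB):   ', num_params * 1 / 1e6)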

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

#####################################################################
# There is a roughly 2x speedup for this model. Also note that the speedup
# may vary depending on the model, device, build, input batch size, threading, etc.
#
# 3. Conclusion
# -------------
#
# This tutorial introduces the API for post training dynamic quantization in FX Graph Mode,
# which dynamically quantizes the same modules as Eager Mode Quantization.