"""
Dynamic Quantization
====================

In this recipe you will see how to take advantage of Dynamic
Quantization to accelerate inference on an LSTM-style recurrent neural
network. This reduces the size of the model weights and speeds up model
execution.

Introduction
-------------

There are a number of trade-offs that can be made when designing neural
networks. During model development and training you can alter the
number of layers and number of parameters in a recurrent neural network
and trade off accuracy against model size and/or model latency or
throughput. Such changes can take a lot of time and compute resources
because you are iterating over the model training. Quantization gives
you a way to make a similar trade-off between performance and model
accuracy with a known model after training is completed.

You can give it a try in a single session and you will certainly reduce
your model size significantly and may get a significant latency
reduction without losing a lot of accuracy.

What is dynamic quantization?
-----------------------------

Quantizing a network means converting it to use a reduced precision
integer representation for the weights and/or activations. This saves on
model size and allows the use of higher throughput math operations on
your CPU or GPU.

When converting from floating point to integer values you are
essentially multiplying the floating point value by some scale factor
and rounding the result to a whole number. The various quantization
approaches differ in the way they determine that scale factor.
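
For example, if the multiplier is chosen so that the observed range
``[-1.0, 1.0]`` fills the ``int8`` range (a multiplier of 127, which is
expressed as ``scale = 1/127`` in PyTorch's convention), the conversion
and the round trip back to floating point look roughly like the sketch
below. This is only an illustration of the idea and is not used by the
rest of the recipe:

.. code-block:: python

    import torch

    x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
    scale = 1.0 / 127       # step size chosen from the observed range [-1.0, 1.0]
    zero_point = 0
    xq = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
    print(xq.int_repr())    # the stored INT8 codes, roughly [-127, 0, 64, 127]
    print(xq.dequantize())  # approximate reconstruction of the original float values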

The key idea with dynamic quantization as described here is that we are
going to determine the scale factor for activations dynamically based on
the data range observed at runtime. This ensures that the scale factor
is "tuned" so that as much signal as possible about each observed
dataset is preserved.

The model parameters, on the other hand, are known during model conversion
and they are converted ahead of time and stored in INT8 form.

Arithmetic in the quantized model is done using vectorized INT8
instructions. Accumulation is typically done with INT16 or INT32 to
avoid overflow. This higher precision value is scaled back to INT8 if
the next layer is quantized, or converted to FP32 for output.

Dynamic quantization is relatively free of tuning parameters, which makes
it well suited to be added into production pipelines as a standard part
of converting LSTM models to deployment.


.. note::
   Limitations on the approach taken here

   This recipe provides a quick introduction to the dynamic quantization
   features in PyTorch and the workflow for using it. Our focus is on
   explaining the specific functions used to convert the model. We will
   make a number of significant simplifications in the interest of brevity
   and clarity:

   1. You will start with a minimal LSTM network
   2. You are simply going to initialize the network with a random hidden
      state
   3. You are going to test the network with random inputs
   4. You are not going to train the network in this tutorial
   5. You will see that the quantized form of this network is smaller and
      runs faster than the floating point network we started with
   6. You will see that the output values are generally in the same
      ballpark as the output of the FP32 network, but we are not
      demonstrating here the expected accuracy loss on a real trained
      network

You will see how dynamic quantization is done and be able to see
suggestive reductions in memory use and latency. Providing a
demonstration that the technique can preserve high levels of model
accuracy on a trained LSTM is left to a more advanced tutorial. If you
want to move right away to that more rigorous treatment, please proceed
to the `advanced dynamic quantization
tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`__.

Steps
-------------

This recipe has 5 steps.

1. Set Up - Here you define a very simple LSTM, import modules, and establish
   some random input tensors.

2. Do the Quantization - Here you instantiate a floating point model and
   then create a quantized version of it.

3. Look at Model Size - Here you show that the model size gets smaller.

4. Look at Latency - Here you run the two models and compare model runtime (latency).

5. Look at Accuracy - Here you run the two models and compare outputs.

1: Set Up
~~~~~~~~~~~~~~~
This is a straightforward bit of code to set up for the rest of the
recipe.

The unique module we are importing here is ``torch.quantization``, which
includes PyTorch's quantized operators and conversion functions. We also
define a very simple LSTM model and set up some inputs.

"""

# import the modules used here in this recipe
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time

# define a very, very simple LSTM for demonstration purposes
# in this case, we are wrapping ``nn.LSTM``, one layer, no preprocessing or postprocessing
# inspired by
# `Sequence Models and Long Short-Term Memory Networks tutorial <https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html>`_, by Robert Guthrie
# and `Dynamic Quantization tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`__.
class lstm_for_demonstration(nn.Module):
    """Elementary Long Short Term Memory style model which simply wraps ``nn.LSTM``
       Not to be used for anything other than demonstration.
    """
    def __init__(self, in_dim, out_dim, depth):
        super(lstm_for_demonstration, self).__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, depth)

    def forward(self, inputs, hidden):
        out, hidden = self.lstm(inputs, hidden)
        return out, hidden


torch.manual_seed(29592)  # set the seed for reproducibility

# shape parameters
model_dimension = 8
sequence_length = 20
batch_size = 1
lstm_depth = 1

# random data for input
inputs = torch.randn(sequence_length, batch_size, model_dimension)
# hidden is actually a tuple of the initial hidden state and the initial cell state
hidden = (torch.randn(lstm_depth, batch_size, model_dimension),
          torch.randn(lstm_depth, batch_size, model_dimension))


######################################################################
# 2: Do the Quantization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Now we get to the fun part. First we create an instance of the model
# called ``float_lstm``, then we are going to quantize it. We're going
# to use the
# `torch.quantization.quantize_dynamic <https://pytorch.org/docs/stable/quantization.html#torch.quantization.quantize_dynamic>`__
# function, which takes the model, then a list of the submodules which
# we want to have quantized if they appear, then the datatype we are
# targeting. This function returns a quantized version of the original
# model as a new module.
#
# That's all it takes.
#

# here is our floating point instance
float_lstm = lstm_for_demonstration(model_dimension, model_dimension, lstm_depth)

# this is the call that does the work
quantized_lstm = torch.quantization.quantize_dynamic(
    float_lstm, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

# show the changes that were made
print('Here is the floating point version of this module:')
print(float_lstm)
print('')
print('and now the quantized version:')
print(quantized_lstm)


######################################################################
# 3. Look at Model Size
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We've quantized the model. What does that get us? Well, the first
# benefit is that we've replaced the FP32 model parameters with INT8
# values (and some recorded scale factors). This means about 75% less data
# to store and move around. With the default values the reduction shown
# below will be less than 75%, but if you increase the model size above
# (for example, you can set model dimension to something like 80) this will
# converge towards 4x smaller, as the stored model size is dominated more
# and more by the parameter values.
#

def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p")
    print("model: ", label, ' \t', 'Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

# compare the sizes
f = print_size_of_model(float_lstm, "fp32")
q = print_size_of_model(quantized_lstm, "int8")
print("{0:.2f} times smaller".format(f/q))


######################################################################
# 4. Look at Latency
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The second benefit is that the quantized model will typically run
# faster. This is due to a combination of effects including at least:
#
# 1. Less time spent moving parameter data in
# 2. Faster INT8 operations
#
# As you will see, the quantized version of this super-simple network runs
# faster. This will generally be true of more complex networks, but as they
# say "your mileage may vary" depending on a number of factors including
# the structure of the model and the hardware you are running on.
#

# compare the performance
print("Floating point FP32")

######################################################################
# .. code-block:: python
#
#    %timeit float_lstm.forward(inputs, hidden)

print("Quantized INT8")

######################################################################
# .. code-block:: python
#
#    %timeit quantized_lstm.forward(inputs, hidden)
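

######################################################################
# The ``%timeit`` calls above only work inside a Jupyter/IPython
# session. If you are running this recipe as a plain Python script, a
# rough alternative is sketched below using the standard-library
# ``timeit`` module; the iteration count is chosen arbitrarily and the
# exact results will depend on your hardware.
#

import timeit

n_iter = 100
fp32_seconds = timeit.timeit(lambda: float_lstm.forward(inputs, hidden), number=n_iter) / n_iter
int8_seconds = timeit.timeit(lambda: quantized_lstm.forward(inputs, hidden), number=n_iter) / n_iter
print('average FP32 forward pass time: {0:.6f} seconds'.format(fp32_seconds))
print('average INT8 forward pass time: {0:.6f} seconds'.format(int8_seconds))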


######################################################################
# 5: Look at Accuracy
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We are not going to do a careful look at accuracy here because we are
# working with a randomly initialized network rather than a properly
# trained one. However, it is worth quickly showing that the quantized
# network does produce output tensors that are "in the same ballpark"
# as the original one.
#
# For a more detailed analysis please see the more advanced tutorials
# referenced at the end of this recipe.
#

# run the float model
out1, hidden1 = float_lstm(inputs, hidden)
mag1 = torch.mean(abs(out1)).item()
print('mean absolute value of output tensor values in the FP32 model is {0:.5f}'.format(mag1))

# run the quantized model
out2, hidden2 = quantized_lstm(inputs, hidden)
mag2 = torch.mean(abs(out2)).item()
print('mean absolute value of output tensor values in the INT8 model is {0:.5f}'.format(mag2))

# compare them
mag3 = torch.mean(abs(out1 - out2)).item()
print('mean absolute value of the difference between the output tensors is {0:.5f} or {1:.2f} percent'.format(mag3, mag3/mag1*100))


######################################################################
# Learn More
# ------------
# We've explained what dynamic quantization is and what benefits it
# brings, and you have used the ``torch.quantization.quantize_dynamic()``
# function to quickly quantize a simple LSTM model.
#
# This was a fast and high level treatment of this material; for more
# detail please continue learning with `(beta) Dynamic Quantization on an LSTM Word Language Model Tutorial <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_.
#
#
# Additional Resources
# --------------------
#
# * `Quantization API Documentation <https://pytorch.org/docs/stable/quantization.html>`_
# * `(beta) Dynamic Quantization on BERT <https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html>`_
# * `(beta) Dynamic Quantization on an LSTM Word Language Model <https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html>`_
# * `Introduction to Quantization on PyTorch <https://pytorch.org/blog/introduction-to-quantization-on-pytorch/>`_
#