Path: blob/main/prototype_source/gpu_quantization_torchao_tutorial.py
"""1(prototype) GPU Quantization with TorchAO2======================================================34**Author**: `HDCharles <https://github.com/HDCharles>`_56In this tutorial, we will walk you through the quantization and optimization7of the popular `segment anything model <https://github.com/facebookresearch/segment-anything>`_. These8steps will mimic some of those taken to develop the9`segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L15>`_10repo. This step-by-step guide demonstrates how you can11apply these techniques to speed up your own models, especially those12that use transformers. To that end, we will focus on widely applicable13techniques, such as optimizing performance with ``torch.compile`` and14quantization and measure their impact.1516"""171819######################################################################20# Set up Your Environment21# --------------------------------22#23# First, let's configure your environment. This guide was written for CUDA 12.1.24# We have run this tutorial on an A100-PG509-200 power limited to 330.00 W. If you25# are using a different hardware, you might see different performance numbers.26#27#28# .. code-block:: bash29#30# > conda create -n myenv python=3.1031# > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu12132# > pip install git+https://github.com/facebookresearch/segment-anything.git33# > pip install git+https://github.com/pytorch-labs/ao.git34#35# Segment Anything Model checkpoint setup:36#37# 1. Go to the `segment-anything repo checkpoint <https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints>`_ and download the ``vit_h`` checkpoint. Alternatively, you can use ``wget`` (for example, ``wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>``).38# 2. 
# 2. Pass in that directory by editing the code below to say:
#
#    .. code-block:: bash
#
#       {sam_checkpoint_base_path}=<path>
#

import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True


@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warmup runs so that compilation and caching don't pollute the measurement
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}


def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image


######################################################################
# In this tutorial, we focus on quantizing the ``image_encoder`` because the
# inputs to it are statically sized, while the prompt encoder and mask
# decoder have variable sizes, which makes them harder to quantize.
#
# We'll focus on just a single block at first to make the analysis easier.
#
# Let's start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)


######################################################################
# We can achieve an instant performance boost by converting the model to bfloat16.
# The reason we opt for bfloat16 over fp16 is its dynamic range, which is comparable
# to that of fp32. Both bfloat16 and fp32 have 8 exponent bits, whereas fp16 only
# has 5. This larger dynamic range helps protect us from overflow errors and other
# issues that can arise when scaling and rescaling tensors due to quantization.
#

model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']:0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory 3.17GB


######################################################################
# Just this quick change improves runtime by a factor of ~7x in the tests we have
# conducted (186.16ms to 25.43ms).
#
# Next, let's use ``torch.compile`` with our model to see how much the performance
# improves.
#

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']:0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory 2.24GB


######################################################################
# The first time this is run, you should see a sequence of ``AUTOTUNE``
# outputs, which appear when inductor compares the performance of
# various kernel parameters for a kernel.
# This only happens once (unless
# you delete your cache), so if you run the cell again you should just get
# the benchmark output.
#
# ``torch.compile`` yields about another 27% improvement. This brings the
# model to a reasonable baseline where we now have to work a bit harder
# for improvements.
#
# Next, let's apply quantization. Quantization for GPUs comes in three main forms
# in `torchao <https://github.com/pytorch-labs/ao>`_, which is just native
# pytorch+python code. This includes:
#
# * int8 dynamic quantization
# * int8 weight-only quantization
# * int4 weight-only quantization
#
# Different models, or sometimes different layers in a model, can require different
# techniques. For models which are heavily compute bound, dynamic quantization tends
# to work best since it swaps the normally expensive floating point
# matmul ops with integer versions. Weight-only quantization works better
# in memory bound situations where the benefit comes from loading less
# weight data, rather than doing less computation. The torchao APIs:
#
# ``change_linear_weights_to_int8_dqtensors``,
# ``change_linear_weights_to_int8_woqtensors``, or
# ``change_linear_weights_to_int4_woqtensors``
#
# can be used to easily apply the desired quantization technique. Then, once the
# model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
# complete and we can see our speedup.
#
# .. note::
#    You might experience issues with these on older versions of PyTorch. If you run
#    into an issue, you can use ``apply_dynamic_quant`` and
#    ``apply_weight_only_int8_quant`` instead as drop-in replacements for the two
#    above (no replacement for int4).
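######################################################################
# To make the compute-bound vs. memory-bound distinction concrete, here is a
# toy, pure-Python sketch of the two int8 schemes applied to a matrix-vector
# product. This is illustrative only: the helpers (``quantize``,
# ``dynamic_matvec``, ``weight_only_matvec``) are made up for this example
# and are not the torchao implementation.

```python
def quantize(vec):
    # symmetric quantization: pick a scale so values map into [-127, 127]
    m = max(abs(v) for v in vec) or 1.0
    return [round(v * 127 / m) for v in vec], m / 127

def dynamic_matvec(weight_rows, x):
    # dynamic: activations are quantized at runtime, the inner product runs
    # in integer arithmetic (an int32 accumulator), then the result is rescaled
    qx, sx = quantize(x)
    out = []
    for row in weight_rows:
        qw, sw = quantize(row)
        acc = sum(a * b for a, b in zip(qw, qx))  # integer matmul
        out.append(acc * sw * sx)                 # rescale back to float
    return out

def weight_only_matvec(prequantized_rows, x):
    # weight-only: weights are stored as int8 (less data to load), but they
    # are dequantized and the arithmetic stays in floating point
    return [sum(q * sw * xi for q, xi in zip(qw, x)) for qw, sw in prequantized_rows]

x = [0.5, -1.0, 2.0]
w = [[1.0, 0.0, -1.0], [0.25, 0.5, 0.75]]
prequantized = [quantize(row) for row in w]       # done once, ahead of time
print(dynamic_matvec(w, x))                       # ~[-1.496, 1.119]
print(weight_only_matvec(prequantized, x))        # ~[-1.5, 1.122]
# exact fp answer for reference: [-1.5, 1.125]
```

# The dynamic scheme replaces the expensive float matmul with an integer one
# (a win when compute-bound), while the weight-only scheme only shrinks the
# stored weight data (a win when memory-bound).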
######################################################################
# The difference between the two APIs is that the ``change_linear_weights`` API
# alters the weight tensor of the linear module, so instead of doing a
# normal linear, it does a quantized operation. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
# APIs directly swap the linear modules for a quantized module, which
# works on older versions but doesn't work with non-standard linear
# modules.
#
# In this case, Segment Anything is compute-bound, so we'll use dynamic quantization:
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory 3.58GB


######################################################################
# With quantization, we have improved performance a bit more, but memory usage
# increased significantly.
#
# This is for two reasons:
#
# 1) Quantization adds overhead to the model
#    since we need to quantize and dequantize the input and output.
#    For small
#    batch sizes this overhead can actually make the model go slower.
# 2) Even though we are doing a quantized matmul, such as ``int8 x int8``,
#    the result of the multiplication gets stored in an int32 tensor,
#    which is twice the size of the result from the non-quantized model.
#    If we can avoid creating this int32 tensor, our memory usage will improve a lot.
#
# We can fix #2 by fusing the integer matmul with the subsequent rescale
# operation. Since the final output will be bf16, if we immediately convert
# the int32 tensor to bf16 and instead store that, we'll get better
# performance in terms of both runtime and memory.
#
# The way to do this is to enable the option
# ``force_fuse_int_mm_with_mul`` in the inductor config.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory 2.37GB


######################################################################
# The fusion improves performance by another small bit (about 6% over the
# baseline in total) and removes almost all of the memory increase; the
# remaining amount (2.37GB quantized vs 2.24GB unquantized) is due to
# quantization overhead, which cannot be helped.
#
# We're still not done, though; we can apply a few more general-purpose
# optimizations to get our final best-case performance.
#
# 1) We can sometimes improve performance by disabling epilogue fusion
#    since the autotuning process can be confused by fusions and choose
#    bad kernel parameters.
# 2) We can apply coordinate descent tuning in all directions to enlarge
#    the search area for kernel parameters.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory 2.39GB


######################################################################
# As you can see, we've squeezed another small improvement from the model,
# taking our total improvement to over 10x compared to our original.
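######################################################################
# As a quick arithmetic check of that cumulative claim, the block-level
# timings measured above multiply out as follows (numbers taken from the
# benchmark printouts earlier in this tutorial):

```python
# block-level median runtimes measured earlier in this tutorial (ms)
fp32_ms, bf16_ms, compiled_ms, final_ms = 186.16, 25.43, 19.95, 18.16

print(f"bf16 over fp32:     {fp32_ms / bf16_ms:.2f}x")      # 7.32x (the ~7x step)
print(f"compile over bf16:  {bf16_ms / compiled_ms:.2f}x")  # 1.27x (the ~27% step)
print(f"total over fp32:    {fp32_ms / final_ms:.2f}x")     # 10.25x (the >10x total)
```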
######################################################################
# To get a final estimate of the impact of quantization, let's do an
# apples-to-apples comparison on the full model, since the actual improvement
# will differ block by block depending on the shapes involved.

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    change_linear_weights_to_int8_dqtensors(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB
except Exception as e:
    print("unable to run full model: ", e)


######################################################################
# Conclusion
# -----------------
# In this tutorial, we have learned about quantization and optimization
# techniques using the segment anything model as an example.
#
# In the end, we achieved a full-model apples-to-apples quantization speedup
# of about 7.7% on batch size 16 (729.65ms to 677.28ms). We can push this a
# bit further by increasing the batch size and optimizing other parts of
# the model. For example, this can be done with some form of flash attention.
#
# For more information, visit
# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
# models.
#
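######################################################################
# As a closing aside, the bfloat16-over-fp16 rationale from earlier can be
# sanity-checked with plain arithmetic. This is an illustrative sketch (the
# ``max_finite`` helper is made up for this example, not part of torchao or
# PyTorch): the largest finite value of an IEEE-style format follows
# directly from its exponent and mantissa widths.

```python
def max_finite(exp_bits, mantissa_bits):
    # The all-ones exponent is reserved for inf/NaN, so the largest usable
    # unbiased exponent is 2**(exp_bits - 1) - 1.
    max_exp = 2 ** (exp_bits - 1) - 1
    # The largest mantissa value is 2 - 2**-mantissa_bits.
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

print(f"fp16 max ~ {max_finite(5, 10):.5g}")   # 65504: easy to overflow
print(f"bf16 max ~ {max_finite(8, 7):.5g}")    # ~3.3895e+38
print(f"fp32 max ~ {max_finite(8, 23):.5g}")   # ~3.4028e+38: same range as bf16
```

# fp16's 65504 ceiling is why the scaling and rescaling done during
# quantization can overflow there, while bf16 shares fp32's range at the
# cost of mantissa precision.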