"""
(prototype) GPU Quantization with TorchAO
======================================================

**Author**: `HDCharles <https://github.com/HDCharles>`_

In this tutorial, we will walk you through the quantization and optimization
of the popular `segment anything model <https://github.com/facebookresearch/segment-anything>`_. These
steps will mimic some of those taken to develop the
`segment-anything-fast <https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py#L15>`_
repo. This step-by-step guide demonstrates how you can
apply these techniques to speed up your own models, especially those
that use transformers. To that end, we will focus on widely applicable
techniques, such as optimizing performance with ``torch.compile`` and
quantization, and then measure their impact.

"""

######################################################################
# Set up Your Environment
# --------------------------------
#
# First, let's configure your environment. This guide was written for CUDA 12.1.
# We have run this tutorial on an A100-PG509-200 power limited to 330.00 W. If you
# are using different hardware, you might see different performance numbers.
#
#
# .. code-block:: bash
#
#    > conda create -n myenv python=3.10
#    > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
#    > pip install git+https://github.com/facebookresearch/segment-anything.git
#    > pip install git+https://github.com/pytorch-labs/ao.git
#
# Segment Anything Model checkpoint setup:
#
# 1. Go to the `segment-anything repo checkpoint <https://github.com/facebookresearch/segment-anything/tree/main#model-checkpoints>`_ and download the ``vit_h`` checkpoint. Alternatively, you can use ``wget`` (for example, ``wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>``).
# 2. Pass in that directory by editing the code below to say:
#
# .. code-block:: bash
#
#    {sam_checkpoint_base_path}=<path>
#
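# Before running the rest of the tutorial, you may want to sanity-check the setup.
# The snippet below is a small optional check (not part of the original tutorial);
# the ``data/sam_vit_h_4b8939.pth`` path is just the default used in the code that follows.
#
# .. code-block:: python
#
#    import os
#    import torch
#
#    assert torch.cuda.is_available(), "this tutorial expects a CUDA-capable GPU"
#    print(torch.__version__, torch.version.cuda)
#    print(os.path.exists("data/sam_vit_h_4b8939.pth"))  # should print True
#
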
import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True


@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warm up so that one-time costs (compilation, autotuning) are not measured
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image


######################################################################
# In this tutorial, we focus on quantizing the ``image_encoder`` because the
# inputs to it are statically sized, while the prompt encoder and mask
# decoder have variable sizes, which makes them harder to quantize.
#
# We’ll focus on just a single block at first to make the analysis easier.
#
# Let's start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)


######################################################################
# We can achieve an instant performance boost by converting the model to bfloat16.
# The reason we opt for bfloat16 over fp16 is its dynamic range, which is comparable to
# that of fp32. Both bfloat16 and fp32 have 8 exponent bits, whereas fp16 only has 5. This
# larger dynamic range helps protect us from overflow errors and other issues that can arise
# when scaling and rescaling tensors due to quantization.
#
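# As a quick illustration of the dynamic-range difference (this snippet is not part
# of the tutorial's benchmark code), fp16 overflows to ``inf`` well before bf16 does:
#
# .. code-block:: python
#
#    import torch
#
#    print(torch.finfo(torch.float16).max)    # 65504.0
#    print(torch.finfo(torch.bfloat16).max)   # ~3.39e38, same order of magnitude as fp32
#    print(torch.tensor(70000.0).to(torch.float16))   # inf (overflow)
#    print(torch.tensor(70000.0).to(torch.bfloat16))  # ~70144, still finite
#
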
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']: 0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory 3.17GB


######################################################################
# Just this quick change improves runtime by a factor of ~7x in the tests we have
# conducted (186.16ms to 25.43ms).
#
# Next, let's use ``torch.compile`` with our model to see how much the performance
# improves.
#

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory 2.24GB


######################################################################
# The first time this is run, you should see a sequence of ``AUTOTUNE``
# outputs, which occur when inductor compares the performance of
# various kernel parameters for each kernel. This only happens once (unless
# you delete your cache), so if you run the cell again you should just get
# the benchmark output.
#
# ``torch.compile`` yields about another 27% improvement. This brings the
# model to a reasonable baseline where we now have to work a bit harder
# for improvements.
#
# Next, let's apply quantization. Quantization for GPUs comes in three main forms
# in `torchao <https://github.com/pytorch-labs/ao>`_, which is just native
# pytorch+python code. This includes:
#
# * int8 dynamic quantization
# * int8 weight-only quantization
# * int4 weight-only quantization
#
# Different models, or sometimes different layers in a model, can require different techniques.
# For models which are heavily compute-bound, dynamic quantization tends
# to work the best since it swaps the normal expensive floating point
# matmul ops with integer versions. Weight-only quantization works better
# in memory-bound situations where the benefit comes from loading less
# weight data, rather than doing less computation. The torchao APIs:
#
# ``change_linear_weights_to_int8_dqtensors``,
# ``change_linear_weights_to_int8_woqtensors`` or
# ``change_linear_weights_to_int4_woqtensors``
#
# can be used to easily apply the desired quantization technique; then,
# once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
# complete and we can see our speedup.
#
# .. note::
#    You might experience issues with these on older versions of PyTorch. If you run
#    into an issue, you can use ``apply_dynamic_quant`` and
#    ``apply_weight_only_int8_quant`` instead as drop-in replacements for the two
#    above (there is no replacement for int4).
#
# The difference between the two sets of APIs is that the ``change_linear_weights`` APIs
# alter the weight tensor of the linear module so that, instead of performing a
# normal linear operation, it performs a quantized one. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
# APIs directly swap the linear modules for a quantized module, which
# works on older versions but doesn’t work with non-standard linear
# modules.
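#
# For reference, here is a sketch (not run in this tutorial) of what applying one of the
# weight-only variants would look like for a memory-bound model, assuming the weight-only
# API is importable from ``torchao.quantization`` like the dynamic variant; it follows the
# same pattern as the dynamic-quantization flow below, just with a different API call:
#
# .. code-block:: python
#
#    from torchao.quantization import change_linear_weights_to_int8_woqtensors
#
#    model, image = get_sam_model(only_one_block, batchsize)
#    model = model.to(torch.bfloat16)
#    image = image.to(torch.bfloat16)
#    change_linear_weights_to_int8_woqtensors(model)  # int8 weight-only instead of dynamic
#    model_c = torch.compile(model, mode='max-autotune')
#    woq_res = benchmark(model_c, image)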
#
# In this case, Segment Anything is compute-bound, so we’ll use dynamic quantization:
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory 3.58GB


######################################################################
# With quantization, we have improved performance a bit more, but memory usage increased
# significantly.
#
# This is for two reasons:
#
# 1) Quantization adds overhead to the model
#    since we need to quantize and dequantize the input and output. For small
#    batch sizes this overhead can actually make the model go slower.
# 2) Even though we are doing a quantized matmul, such as ``int8 x int8``,
#    the result of the multiplication gets stored in an int32 tensor,
#    which is twice the size of the result from the non-quantized model.
#    If we can avoid creating this int32 tensor, our memory usage will improve a lot.
#
# We can fix #2 by fusing the integer matmul with the subsequent rescale
# operation: since the final output will be bf16, if we immediately convert
# the int32 tensor to bf16 and store that instead, we’ll get better
# performance in terms of both runtime and memory.
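#
# Conceptually, the unfused pattern looks roughly like the sketch below (illustrative
# only, using the internal ``torch._int_mm`` op with made-up shapes; the real kernels
# are generated by inductor):
#
# .. code-block:: python
#
#    x_int8 = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device='cuda')
#    w_int8 = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device='cuda')
#    scales = torch.rand(32, dtype=torch.bfloat16, device='cuda')
#
#    acc_int32 = torch._int_mm(x_int8, w_int8)    # int32 intermediate is materialized
#    out = acc_int32.to(torch.bfloat16) * scales  # separate rescale back to bf16
#
# With the fusion enabled, the matmul and the rescale run as one kernel, so the bf16
# result is stored directly instead of the larger int32 intermediate.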
#
# The way to do this is to enable the option
# ``force_fuse_int_mm_with_mul`` in the inductor config.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory 2.37GB


######################################################################
# The fusion improves performance by another small bit (about 6% over the
# baseline in total) and removes almost all of the memory increase; the
# remaining difference (2.37GB quantized vs 2.24GB unquantized) is due to
# quantization overhead, which cannot be helped.
#
# We’re still not done, though; we can apply a few general purpose
# optimizations to get our final best-case performance.
#
# 1) We can sometimes improve performance by disabling epilogue fusion,
#    since the autotuning process can be confused by fusions and choose
#    bad kernel parameters.
# 2) We can apply coordinate descent tuning in all directions to enlarge
#    the search area for kernel parameters.
#

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory 2.39GB


######################################################################
# As you can see, we’ve squeezed another small improvement from the model,
# taking our total improvement to over 10x compared to our original fp32 baseline. To
# get a final estimate of the impact of quantization, let's do an apples-to-apples
# comparison on the full model, since the actual improvement will
# differ block by block depending on the shapes involved.
#

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    change_linear_weights_to_int8_dqtensors(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB
except Exception as e:
    print("unable to run full model: ", e)


######################################################################
# Conclusion
# -----------------
# In this tutorial, we have learned about quantization and optimization techniques,
# using the Segment Anything model as an example.
#
# In the end, we achieved a full-model, apples-to-apples quantization speedup
# of about 7.7% on batch size 16 (729.65ms to 677.28ms). We can push this a
# bit further by increasing the batch size and optimizing other parts of
# the model. For example, this can be done with some form of flash attention.
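#
# As a pointer in that direction (not something this tutorial benchmarks), PyTorch's
# built-in ``torch.nn.functional.scaled_dot_product_attention`` can dispatch to a
# fused flash-attention kernel when the inputs allow it; a rough sketch with made-up shapes:
#
# .. code-block:: python
#
#    import torch
#    import torch.nn.functional as F
#
#    q = torch.randn(16, 8, 1024, 64, dtype=torch.bfloat16, device='cuda')
#    k = torch.randn(16, 8, 1024, 64, dtype=torch.bfloat16, device='cuda')
#    v = torch.randn(16, 8, 1024, 64, dtype=torch.bfloat16, device='cuda')
#
#    # may use a fused flash-attention kernel under the hood
#    out = F.scaled_dot_product_attention(q, k, v)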
#
# For more information, visit
# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
# models.
#