"""
Title: GPTQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/16
Last modified: 2025/10/16
Description: How to run weight-only GPTQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is GPTQ?

GPTQ (Generative Pre-trained Transformer Quantization) is a post-training,
weight-only quantization method. It works one layer at a time, using a
second-order approximation of the layer's output error (via a Hessian estimate
computed from calibration data) to minimize the error introduced when
compressing weights to lower precision, typically 4-bit integers.

Unlike post-training schemes that quantize both weights and activations, GPTQ
keeps activations in higher precision and quantizes only the weights. This often
preserves model quality at low bit widths while still providing large storage
and memory savings.
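
To make the second-order idea concrete, here is a minimal NumPy sketch of the core
GPTQ update for a single linear layer. It is written from the algorithm described
in the original GPTQ paper, not taken from Keras's implementation, and it assumes
one asymmetric scale and zero-point per output row while omitting grouping,
blocked updates, and activation ordering. The helper names (`row_grid`,
`fake_quant`, `gptq_quantize`) are hypothetical. Each input column is rounded in
turn, and its rounding error is pushed onto the columns not yet quantized using
the upper Cholesky factor of the inverse Hessian. The calibration inputs are
synthetic and deliberately correlated.

```python
import numpy as np


def row_grid(W, bits):
    # Per-row asymmetric grid whose range always includes 0 (assumed convention).
    maxq = 2**bits - 1
    wmin = np.minimum(W.min(axis=1), 0.0)
    wmax = np.maximum(W.max(axis=1), 0.0)
    scale = (wmax - wmin) / maxq
    scale[scale == 0] = 1.0  # avoid dividing by zero for all-zero rows
    zero = np.round(-wmin / scale)
    return scale, zero, maxq


def fake_quant(w, scale, zero, maxq):
    # Round onto the integer grid [0, maxq], then map back to real values.
    q = np.clip(np.round(w / scale) + zero, 0, maxq)
    return scale * (q - zero)


def gptq_quantize(W, X, bits=4, damping=0.01):
    # W: (out_features, in_features) weights; X: (in_features, n) calibration inputs.
    W = np.array(W, dtype=np.float64)  # working copy that absorbs error feedback
    d = W.shape[1]
    scale, zero, maxq = row_grid(W, bits)

    # Hessian of the layer-output error ||W X - Q X||^2, shared by every row.
    H = 2.0 * X @ X.T / X.shape[1]
    # Dampening: add a fraction of the mean diagonal for numerical stability.
    H += damping * np.mean(np.diag(H)) * np.eye(d)
    # Upper Cholesky factor U of the inverse Hessian (H^-1 = U^T U).
    U = np.linalg.cholesky(np.linalg.inv(H)).T

    Q = np.zeros_like(W)
    for i in range(d):
        Q[:, i] = fake_quant(W[:, i], scale, zero, maxq)
        # Push this column's rounding error onto the columns not yet quantized.
        err = (W[:, i] - Q[:, i]) / U[i, i]
        W[:, i + 1:] -= np.outer(err, U[i, i + 1:])
    return Q


# Synthetic layer whose inputs are correlated (AR(1)-style covariance).
rng = np.random.default_rng(0)
idx = np.arange(64)
cov = 0.9 ** np.abs(idx[:, None] - idx[None, :])
X = np.linalg.cholesky(cov) @ rng.normal(size=(64, 1024))
W = rng.normal(size=(32, 64))

# Baseline: plain round-to-nearest on the same grid.
scale, zero, maxq = row_grid(W, bits=4)
Q_rtn = fake_quant(W, scale[:, None], zero[:, None], maxq)
Q_gptq = gptq_quantize(W, X, bits=4)

rtn_err = np.linalg.norm(W @ X - Q_rtn @ X)
gptq_err = np.linalg.norm(W @ X - Q_gptq @ X)
print(f"round-to-nearest output error: {rtn_err:.1f}")
print(f"GPTQ output error:             {gptq_err:.1f}")
```

With correlated inputs like these, the error feedback lets later columns
compensate for earlier rounding, so the GPTQ output error should land well below
round-to-nearest on the same grid. As the inputs become less correlated, the
advantage shrinks.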

Keras supports GPTQ quantization for KerasHub models via the
`keras.quantizers.GPTQConfig` class.
"""

"""
## Load a KerasHub model

This guide uses KerasHub's `Gemma3CausalLM` model with the `gemma3_1b` preset, a
small (1B-parameter) causal language model.
"""
import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

# Baseline (unquantized) generation, for comparison after quantization.
outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Configure & run GPTQ quantization

You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.

The GPTQ configuration requires a calibration dataset and a tokenizer. GPTQ runs
the calibration text through the model to collect each layer's inputs, uses them
to estimate that layer's Hessian, and then chooses quantized weights that
minimize the resulting layer-output error. Here, we use a small slice of the
WikiText-2 dataset for calibration.
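
For a linear layer, that Hessian is built from the calibration inputs, roughly
`H = 2 X X^T / N` over all `N` input vectors. Below is a minimal NumPy sketch of
accumulating it batch by batch as a running average, in the spirit of the
reference GPTQ implementation rather than Keras's own code. The helper
`accumulate_hessian` is hypothetical.

```python
import numpy as np


def accumulate_hessian(batches):
    # Running estimate of H = (2 / N) * (sum of x x^T over all N input columns),
    # updated one batch at a time so only one batch of inputs is in memory.
    H, n_seen = None, 0
    for X in batches:  # each X has shape (in_features, columns_in_batch)
        m = X.shape[1]
        if H is None:
            H = np.zeros((X.shape[0], X.shape[0]))
        H *= n_seen / (n_seen + m)  # rescale the old average to the new count
        n_seen += m
        H += (2.0 / n_seen) * (X @ X.T)
    return H


rng = np.random.default_rng(0)
batches = [rng.normal(size=(16, m)) for m in (100, 37, 250)]
H_stream = accumulate_hessian(batches)

# Same quantity computed in one shot from all calibration inputs.
X_all = np.concatenate(batches, axis=1)
H_full = 2.0 * X_all @ X_all.T / X_all.shape[1]
print(np.allclose(H_stream, H_full))  # True
```

The streaming form gives the same result as the one-shot formula while keeping
memory bounded by a single batch of activations.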

You can tune several parameters to trade off speed, memory, and accuracy. The
most important are `weight_bits` (the bit width to quantize weights to) and
`group_size` (the number of consecutive weights that share the same quantization
parameters: a scale and, for asymmetric quantization, a zero-point). The group
size controls the granularity of quantization: smaller groups typically yield
better accuracy, but they are slower to quantize and add more per-group
parameters, which may increase memory use. A good starting point is
`group_size=128` for 4-bit quantization (`weight_bits=4`).
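
The effect of `group_size` can be seen with plain round-to-nearest quantization
(not GPTQ itself): each group of consecutive weights gets its own asymmetric scale
and zero-point, so smaller groups fit tighter ranges at the cost of more
metadata. This sketch assumes a 16-bit scale and a 16-bit zero-point per group
when estimating bits per weight; that storage format is an illustration, not how
Keras packs GPTQ weights. `groupwise_fake_quant` is a hypothetical helper.

```python
import numpy as np


def groupwise_fake_quant(w, bits=4, group_size=128):
    # Asymmetric round-to-nearest with one (scale, zero-point) per group of
    # `group_size` consecutive weights; returns the dequantized weights.
    maxq = 2**bits - 1
    groups = w.reshape(-1, group_size)  # assumes w.size % group_size == 0
    gmin = np.minimum(groups.min(axis=1, keepdims=True), 0.0)
    gmax = np.maximum(groups.max(axis=1, keepdims=True), 0.0)
    scale = np.where(gmax > gmin, (gmax - gmin) / maxq, 1.0)
    zero = np.round(-gmin / scale)
    q = np.clip(np.round(groups / scale) + zero, 0, maxq)
    return (scale * (q - zero)).reshape(w.shape)


rng = np.random.default_rng(0)
w = rng.normal(size=4096)  # stand-in for one row of a weight matrix

errors = {}
for gs in (32, 128, 1024):
    errors[gs] = np.mean((w - groupwise_fake_quant(w, bits=4, group_size=gs)) ** 2)
    # Assumed metadata: a 16-bit scale and a 16-bit zero-point per group.
    bits_per_weight = 4 + 32 / gs
    print(f"group_size={gs:4d}  MSE={errors[gs]:.5f}  ~{bits_per_weight:.2f} bits/weight")
```

Smaller groups lower the reconstruction error but raise the per-weight overhead,
which is the trade-off behind the `group_size=128` starting point.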

In this example, we first prepare a small calibration set and then run GPTQ on
the model using the `.quantize(...)` API.
"""

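# Calibration slice (use a larger, representative set in practice).
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

# Split each line into sentences and drop empty strings.
calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

gptq_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,  # quantize weights to 4-bit integers
    group_size=128,  # consecutive weights sharing quantization parameters
    num_samples=256,  # number of calibration samples
    sequence_length=256,  # tokens per calibration sample
    hessian_damping=0.01,  # relative dampening added to the Hessian diagonal
    symmetric=False,  # asymmetric quantization (uses a zero-point)
    activation_order=False,  # keep the original column order
)

# Quantize the model in place, then generate again to compare with the baseline.
model.quantize("gptq", config=gptq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)
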
"""
## Model Export

The GPTQ-quantized model can be saved to a preset and reloaded elsewhere, just
like any other KerasHub model.
"""

# Save the quantized model, then reload it as a fresh model instance.
model.save_to_preset("gemma3_gptq_w4gs128_preset")
model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
output = model_from_preset.generate(prompt, max_length=30)
print(output)

"""
## Performance & Benchmarking

Micro-benchmarks were collected on a single NVIDIA GeForce RTX 4070 Ti SUPER
(16 GB). Baselines are FP32, and each Δ % column is the relative change of the
GPTQ-quantized model versus its FP32 baseline.

Dataset: WikiText-2.

| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Δ % (↓ better) | VRAM Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |
| --------------------------------- | -------------------------------: | --------------------------: | ------------------: | ---------------------------------: | ------------------------: |
| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |
| OPT (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |
| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |
| Gemma3 (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |

Detailed benchmarking numbers and scripts are available
[here](https://github.com/keras-team/keras/pull/21641).

### Analysis

All models show a notable reduction in disk and VRAM usage: disk savings are
around 50%, and VRAM savings range from 41% to 54%. The reported disk savings
understate the true weight compression because presets also include non-weight
assets.

Perplexity increases by 1% to 10% (smallest for GPT-2, largest for OPT), so
model quality is largely preserved after quantization, although the impact
varies by model. Speed results are mixed: first-token latency increases for
every model (by up to 39.5% for Gemma3), while throughput improves for GPT-2 and
Gemma3 but drops by 15.7% for OPT and Bloom.
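
The Δ % columns are relative changes against the FP32 baseline. The sketch below
spells out that arithmetic and, for comparison, an idealized ceiling for weight
storage: if every weight were stored at 4 bits with `group_size=128` plus an
assumed 32 bits of scale/zero-point metadata per group, a model with 1B FP32
weights would shrink by about 86.7%. Both helpers (`pct_delta`,
`ideal_weight_bytes`) and the metadata size are hypothetical illustrations, not
how Keras serializes presets.

```python
def pct_delta(baseline, quantized):
    # Relative change in percent; negative values mean the metric went down.
    return 100.0 * (quantized - baseline) / baseline


def ideal_weight_bytes(num_weights, bits, group_size=None, meta_bits=32):
    # Idealized storage: `bits` per weight plus, when grouped, `meta_bits` of
    # scale/zero-point metadata per group (an assumed, simplified format).
    total_bits = num_weights * bits
    if group_size:
        total_bits += (num_weights // group_size) * meta_bits
    return total_bits / 8


num_weights = 1_000_000_000  # hypothetical model whose weights are all quantized
fp32_bytes = ideal_weight_bytes(num_weights, bits=32)
w4_bytes = ideal_weight_bytes(num_weights, bits=4, group_size=128)
print(f"Ideal weight-storage change: {pct_delta(fp32_bytes, w4_bytes):.1f}%")
# Ideal weight-storage change: -86.7%
```

The gap between this ceiling and the measured disk reduction of roughly 50%
suggests that a sizeable share of each preset is not stored as packed 4-bit
weights.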
"""

"""
## Practical tips

* GPTQ is a post-training technique; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only suitable for demos.
* Start with `weight_bits=4` and `group_size=128`, then tune per model and task.
"""