"""
2
Title: INT4 Quantization in Keras
3
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
4
Date created: 2025/10/14
5
Last modified: 2025/10/14
6
Description: Complete guide to using INT4 quantization in Keras and KerasHub.
7
Accelerator: GPU
8
"""
9
10
"""
11
## What is INT4 quantization?
12
13
Quantization lowers the numerical precision of weights and activations to reduce memory use
14
and often speed up inference, at the cost of a small accuracy drop. INT4 post-training
15
quantization (PTQ) stores model weights in 4-bit signed integers and dynamically quantizes
16
activations to 8-bit at runtime (a W4A8 scheme). Compared with FP32 this can shrink weight
17
storage ~8x (2x vs INT8) while retaining acceptable accuracy for many encoder models and
18
some decoder models. Compute still leverages widely available NVIDIA INT8 Tensor Cores.
19
20
4-bit is a more aggressive compression than 8-bit and may induce larger quality regressions,
21
especially for large autoregressive language models.
22
23
## How it works
24
25
Quantization maps real values to 4-bit integers with a scale:
26
27
1. Per-output-channel scale computed for each weight matrix (symmetric abs-max).
28
2. Weights are quantized to values in `[-8, 7]` (4 bits) and packed two-per-byte.
29
3. At inference, activations are dynamically scaled and quantized to INT8.
30
4. Packed INT4 weights are unpacked to an INT8 tensor (still with INT4-range values).
31
5. INT8 x INT8 matmul accumulates in INT32.
32
6. Result is dequantized using `(input_scale * per_channel_kernel_scale)`.
33
34
This mirrors the INT8 path described in the
35
[INT8 guide](https://keras.io/guides/int8_quantization_in_keras) with some added unpack
36
overhead for stronger compression.
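
The following is a minimal NumPy sketch of the arithmetic above. It is illustrative only:
the helper names (`quantize_weights_int4`, `w4a8_matmul`) are made up for this example, and
the real Keras kernels additionally pack two INT4 values per byte and run the INT8 matmul
on the accelerator.

```python
import numpy as np


def quantize_weights_int4(w):
    # Symmetric abs-max scale per output channel (last axis of the kernel).
    scale = np.maximum(np.abs(w).max(axis=0), 1e-8) / 7.0  # INT4 range: [-8, 7]
    q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
    return q, scale


def w4a8_matmul(x, q_w, w_scale):
    # Dynamic abs-max quantization of activations to INT8 (per-tensor for simplicity).
    x_scale = np.maximum(np.abs(x).max(), 1e-8) / 127.0
    q_x = np.clip(np.round(x / x_scale), -127, 127).astype(np.int8)
    # INT8 x INT8 matmul accumulating in INT32, then dequantize with both scales.
    acc = q_x.astype(np.int32) @ q_w.astype(np.int32)
    return acc.astype(np.float32) * (x_scale * w_scale)


rng = np.random.default_rng(0)
w = rng.standard_normal((10, 32)).astype("float32")
x = rng.standard_normal((4, 10)).astype("float32")
q_w, w_scale = quantize_weights_int4(w)
print("max abs error:", np.abs(w4a8_matmul(x, q_w, w_scale) - x @ w).max())
```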

## Benefits
* Memory / bandwidth bound models: When a layer spends most of its time on memory I/O,
reducing compute time alone does little for overall runtime. INT4 reduces bytes
moved by ~8x vs FP32, improving cache behavior and reducing memory stalls;
this often helps more than increasing raw FLOPs.
* Accuracy: Many architectures retain acceptable accuracy with INT4; encoder-only models
often fare better than decoder LLMs. Always validate on your own dataset.
* Compute bound layers on supported hardware: 4-bit weights are unpacked to INT8 at inference,
so on NVIDIA GPUs, INT8 [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)
can speed up matmul/conv, boosting throughput on compute-limited layers.

### What Keras does in INT4 mode

* **Mapping**: Symmetric, linear quantization with INT4 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic abs-max** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
is rewritten so you can run or save immediately.
"""

"""
## Overview

This guide shows how to use 4-bit (W4A8) post-training quantization in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
4. [When to use INT4 vs INT8](#when-should-i-use-int4-vs-int8)
5. [Performance benchmarks](#performance--benchmarking)
6. [Practical Tips](#practical-tips)
7. [Limitations](#limitations)
"""

"""
73
## Quantizing a Minimal Functional Model
74
75
Below we build a small functional model, capture a baseline output, quantize to INT4
76
in place, and compare outputs with an MSE metric. (For real evaluation use your
77
validation metric.)
78
"""
79
80
import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)

# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")

# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))

"""
107
The INT4 quantized model usually produces outputs close enough for many downstream
108
tasks. Expect larger deltas than INT8, so always validate on your own data.
109
"""
110
111
"""
112
## Saving and Reloading a Quantized Model
113
114
You can use standard Keras saving / loading APIs. Quantization metadata (including
115
scales and packed weights) is preserved.
116
"""
117
118
# Save the quantized model and reload to verify round-trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)

# Compare outputs (MSE).
roundtrip_mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4_reloaded))
print("MSE (FP32 vs reloaded INT4):", float(roundtrip_mse))

"""
128
## Quantizing a KerasHub Model
129
130
All KerasHub models support the `.quantize(...)` API for post-training quantization,
131
and follow the same workflow as above.
132
133
In this example, we will:
134
135
1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
136
preset from KerasHub
137
2. Generate text using both the full-precision and quantized models, and compare outputs.
138
3. Save both models to disk and compute storage savings.
139
4. Reload the INT4 model and verify output consistency with the original quantized model.
140
"""
141
import os
from keras_hub.models import Gemma3CausalLM

# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)

# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")

# Quantize to INT4.
gemma3.quantize("int4")

# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)

# Save INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")

# Reload and compare outputs
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")

output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)


# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
185
## Performance & Benchmarking
186
187
Micro-benchmarks collected on a single NVIDIA L4 (22.5 GB). Baselines are FP32.
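
One way to collect such latency numbers is a simple timing loop with a warm-up pass. The
sketch below is illustrative only (`avg_latency_ms` is a hypothetical helper); the linked
Colab notebooks contain the exact benchmarking code.

```python
import time


def avg_latency_ms(fn, n_iters=10):
    fn()  # Warm-up call; the first call may trace/compile the model.
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    return 1000.0 * (time.perf_counter() - start) / n_iters


# Example: average end-to-end generation latency of the INT4 Gemma3 model from above.
print(avg_latency_ms(lambda: gemma3_int4.generate("Keras is a", max_length=30)))
```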

### Text Classification (DistilBERT Base on SST-2)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/77e874187d6da3f8280c053192f78c06/int4-quantization-micro-benchmark-distilbert.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Accuracy (↑) | 91.06% | 90.14% | -0.92pp |
| Model Size (MB, ↓) | 255.86 | 159.49 | -37.67% |
| Peak GPU Memory (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
| Latency (ms/sample, ↓) | 6.43 | 5.73 | -10.83% |
| Throughput (samples/s, ↑) | 155.60 | 174.50 | +12.15% |

**Analysis**: Accuracy drop is modest (<1pp) with notable speed and memory gains;
encoder-only models tend to retain fidelity under heavier weight compression.

### Text Generation (Falcon 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/19ab238e0f5b29ae24c0faf4128e7d7e/int4_quantization_micro_benchmark_falcon.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 7.44 | 9.98 | +34.15% |
| Model Size (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
| Peak GPU Memory (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
| First Token Latency (ms, ↓) | 128.87 | 122.50 | -4.95% |
| Sequence Latency (ms, ↓) | 338.29 | 181.93 | -46.22% |
| Token Throughput (tokens/s, ↑) | 174.41 | 256.96 | +47.33% |

**Analysis**: INT4 gives large size (-80%) and memory (-32%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Gemma3 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/9ca7813971868d5d1a16cd7998d0e352/int4_quantization_micro_benchmark_gemma3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.17 | 10.46 | +69.61% |
| Model Size (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
| Peak GPU Memory (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
| First Token Latency (ms, ↓) | 57.42 | 64.21 | +11.83% |
| Sequence Latency (ms, ↓) | 239.78 | 161.18 | -32.78% |
| Token Throughput (tokens/s, ↑) | 246.06 | 366.05 | +48.76% |

**Analysis**: INT4 gives large size (-61%) and memory (-27%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Llama 3.2 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/310f50a0ca0eba3754de41c612b3b8ef/int4_quantization_micro_benchmark_llama3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.38 | 14.16 | +121.78% |
| Model Size (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
| Peak GPU Memory (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
| First Token Latency (ms, ↓) | 209.41 | 219.09 | +4.62% |
| Sequence Latency (ms, ↓) | 322.33 | 262.15 | -18.67% |
| Token Throughput (tokens/s, ↑) | 183.82 | 230.78 | +25.55% |

**Analysis**: INT4 gives large size (-57%) and memory (-28%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~25%.

### Text Generation (OPT 125M)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/918fcdb8a1433dea12800f8ca4a240f5/int4_quantization_micro_benchmark_opt.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 13.85 | 21.02 | +51.79% |
| Model Size (MB, ↓) | 468.3 | 284.0 | -39.37% |
| Peak GPU Memory (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
| First Token Latency (ms/sample, ↓) | 95.79 | 97.87 | +2.18% |
| Sequence Latency (ms/sample, ↓) | 60.35 | 54.64 | -9.46% |
| Throughput (samples/s, ↑) | 973.41 | 1075.15 | +10.45% |

**Analysis**: INT4 gives large size (-39%) and memory (-35%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~10%.
"""

"""
274
## When should I use INT4 vs INT8?
275
276
| Goal / Constraint | Prefer INT8 | Prefer INT4 (W4A8) |
277
| ----------------- | ----------- | ------------------ |
278
| Minimal accuracy drop critical | ✔︎ | |
279
| Maximum compression (disk / RAM) | | ✔︎ |
280
| Bandwidth-bound inference | Possible | Often better |
281
| Decoder LLM | ✔︎ | Try with eval first |
282
| Encoder / classification models | ✔︎ | ✔︎ |
283
| Available kernels / tooling maturity | ✔︎ | Emerging |
284
285
* Start with INT8; if memory or distribution size is still a bottleneck, evaluate INT4.
286
* For LLMs, measure task-specific metrics (perplexity, exact match, etc.) after INT4.
287
* Combine INT4 weights + LoRA adapters for efficient fine-tuning workflows.
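
As a quick comparison of the two modes, the sketch below quantizes fresh copies of the toy
model from earlier in both INT8 and INT4 and prints the saved file sizes. It is illustrative
only; also compare accuracy on your own task.

```python
import os
import keras
from keras import layers


def build_model():
    inputs = keras.Input(shape=(10,))
    x = layers.Dense(32, activation="relu")(inputs)
    return keras.Model(inputs, layers.Dense(1)(x))


# Quantization is in-place and one-way, so use a fresh model per mode.
for mode in ("int8", "int4"):
    m = build_model()
    m.quantize(mode)
    m.save(f"demo_{mode}.keras")
    print(mode, "->", os.path.getsize(f"demo_{mode}.keras"), "bytes")
```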
"""

"""
## Practical Tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
after quantizing it to INT4.
* Always materialize weights before quantization (e.g., via `build()` or a forward pass);
see the sketch after this list.
* Evaluate on a representative validation set; track task metrics, not just MSE.
* Use LoRA adapters if you need further fine-tuning (the quantized base weights cannot be
trained directly).
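
A minimal sketch of the "materialize weights first" tip on a single layer, assuming your
installed Keras version supports layer-level `quantize("int4")` on `Dense`:

```python
import keras
from keras import layers

layer = layers.Dense(8)

# Build (or call) the layer so its weights exist before quantizing.
layer.build((None, 4))  # alternatively: layer(keras.ops.ones((1, 4)))
layer.quantize("int4")
print([v.name for v in layer.weights])
```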

## Limitations
* Runtime unpacking adds overhead (weights are decompressed layer-wise on each forward pass).
* Aggressive compression can reduce accuracy (especially for decoder-only LLMs).
* The LoRA export path is lossy (dequantize -> add delta -> requantize).
* Keras does not yet ship native fused INT4 kernels; it relies on unpacking plus an INT8 matmul.
"""