# GitHub repository: keras-team/keras-io
# Path: blob/master/guides/int8_quantization_in_keras.py
"""
Title: 8-bit Integer Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
## What is INT8 quantization?

Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
improve throughput and latency. Actual gains depend on your backend and device.
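
The ratios above follow from the bytes stored per weight: 4 for `float32`, 2 for `float16`,
and 1 for INT8. The sketch below counts raw weight storage only for a hypothetical
1-billion-parameter model; quantization scales, activations, and runtime buffers are
deliberately ignored, so real footprints will be somewhat larger.

```python
# Bytes used to store one weight in each format.
BYTES_PER_WEIGHT = {"float32": 4, "float16": 2, "int8": 1}


def weight_memory_gib(num_params, dtype):
    # Raw weight storage only; scales and activations are not counted.
    return num_params * BYTES_PER_WEIGHT[dtype] / 1024**3


num_params = 1_000_000_000  # hypothetical model size
for dtype in ("float32", "float16", "int8"):
    print(f"{dtype:>8}: {weight_memory_gib(num_params, dtype):.2f} GiB")
```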

### How it works

Quantization maps real values to 8-bit integers with a scale:

* Integer domain: `[-128, 127]` (256 levels).
* For a tensor (often per-output-channel for weights) with values `w`:
    * Compute `a_max = max(abs(w))`.
    * Set scale `s = (2 * a_max) / 256`.
    * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
* Inference uses `q` and `s` to reconstruct effective weights on the fly
  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
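
A minimal NumPy sketch of these steps, using the formulas exactly as written above. It assumes
a kernel of shape `(in_features, out_features)`, as in a Keras `Dense` layer, so each output
channel is a column and the per-channel maximum is taken over axis 0. The helper names are
hypothetical and this is not Keras's internal implementation.

```python
import numpy as np


def quantize_per_channel(w):
    # One scale per output channel (column): reduce over axis 0.
    a_max = np.max(np.abs(w), axis=0, keepdims=True)
    s = (2.0 * a_max) / 256.0  # scale formula from the list above
    s = np.maximum(s, np.finfo(w.dtype).tiny)  # guard against all-zero columns
    q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
    return q, s


def dequantize(q, s):
    # Reconstruct effective weights: w ≈ s * q.
    return q.astype(np.float32) * s


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 16)).astype(np.float32)
q, s = quantize_per_channel(w)
w_hat = dequantize(q, s)
print("int8 range used:", q.min(), q.max())
print("max abs reconstruction error:", np.max(np.abs(w - w_hat)))
```

With `s = (2 * a_max) / 256`, a value of magnitude `a_max` maps to ±128 before clipping: the
negative extreme fits exactly at `-128`, while the positive extreme is clipped to `127`, so its
reconstruction error can reach one full step `s` rather than the usual `s / 2`. Many
implementations use a symmetric `[-127, 127]` range with `s = a_max / 127` to avoid this.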

### Benefits

* Memory/bandwidth-bound models: When a model spends most of its time on memory I/O,
  reducing computation time does little to reduce its overall runtime. INT8 reduces the bytes
  moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
  this often helps more than increasing raw FLOPs.
* Compute-bound layers on supported hardware: On NVIDIA GPUs, INT8
  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
  boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.
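
One way to see why the byte count matters for memory-bound workloads: a matrix-vector product
(for example, one decoding step at batch size 1) performs only two floating-point operations
per weight it reads. The sketch below counts weight traffic only (an assumption: activation and
output traffic are ignored) for a hypothetical 4096 x 4096 layer.

```python
def matvec_arithmetic_intensity(rows, cols, bytes_per_weight):
    # FLOPs per byte of weight traffic for y = W @ x with W of shape (rows, cols).
    flops = 2 * rows * cols  # one multiply and one add per weight
    weight_bytes = rows * cols * bytes_per_weight  # only weights are counted
    return flops / weight_bytes


for name, nbytes in (("float32", 4), ("float16", 2), ("int8", 1)):
    print(name, matvec_arithmetic_intensity(4096, 4096, nbytes), "FLOPs/byte")
```

In a simple roofline view, such a low FLOPs-per-byte ratio means the layer is limited by
memory bandwidth, so cutting the bytes per weight raises attainable throughput even if the
arithmetic itself does not get faster.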

### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after the weights are trained and built; the graph
  is rewritten so you can run or save immediately.
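
To make "folds `s` into the matmul" concrete, here is a NumPy sketch of an INT8 dense forward
pass combining per-output-channel weight scales with dynamic AbsMax activation scales, where
both scales are applied once to the integer accumulator. It reuses the scale formula from
"How it works"; the per-row activation granularity, the int32 accumulator, and the helper
names are assumptions for illustration, not a description of Keras internals.

```python
import numpy as np


def absmax_quantize_np(t, axis):
    # Symmetric AbsMax quantization along `axis`, using s = 2 * a_max / 256.
    a_max = np.max(np.abs(t), axis=axis, keepdims=True)
    s = np.maximum(2.0 * a_max / 256.0, np.finfo(np.float32).tiny)
    q = np.clip(np.round(t / s), -128, 127).astype(np.int8)
    return q, s.astype(np.float32)


def int8_dense_np(x, w):
    # Activations: dynamic scales computed at call time, one per row (assumption).
    q_x, s_x = absmax_quantize_np(x, axis=-1)  # s_x: (batch, 1)
    # Weights: one scale per output channel (column of the kernel).
    q_w, s_w = absmax_quantize_np(w, axis=0)  # s_w: (1, out_features)
    # Integer matmul, accumulated in int32 so int8 products cannot overflow.
    acc = q_x.astype(np.int32) @ q_w.astype(np.int32)
    # Fold both scales into a single rescale of the accumulator.
    return acc.astype(np.float32) * s_x * s_w


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64)).astype(np.float32)
w = rng.standard_normal((64, 16)).astype(np.float32)
y_ref = x @ w
y_q = int8_dense_np(x, w)
rel_err = np.linalg.norm(y_q - y_ref) / np.linalg.norm(y_ref)
print("relative error:", rel_err)
```

Because `x[i, k] ≈ s_x[i] · q_x[i, k]` and `w[k, j] ≈ s_w[j] · q_w[k, j]`, the scales factor
out of the sum over `k`, which is why a single multiply after the integer matmul suffices.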
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize it to INT8 in place,
and then compare the outputs using mean squared error (MSE).
"""

import os
import numpy as np
import keras
from keras import layers


# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (FP) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))

"""
The printed MSE should be small, indicating that the INT8 quantized model produces outputs
close to those of the original FP32 model.

## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to `.keras` and loading back.
"""

# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
   preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute the storage savings.
4. Reload the INT8 model and verify that its output matches the original quantized model.
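
Before running it, it can help to know roughly what size reduction to expect. The sketch below
estimates the saving for a single dense kernel, assuming float32 storage for the
full-precision weights and one float32 scale per output channel for the INT8 kernel. The layer
shape and the helper name are hypothetical, and a real checkpoint also holds tensors (such as
normalization weights) that may stay in floating point, so measured savings will differ somewhat.

```python
def expected_int8_reduction(in_features, out_features, fp_bytes=4, scale_bytes=4):
    # Full-precision kernel: fp_bytes per weight.
    fp_size = in_features * out_features * fp_bytes
    # INT8 kernel: one byte per weight plus one scale per output channel.
    int8_size = in_features * out_features * 1 + out_features * scale_bytes
    return 100.0 * (1.0 - int8_size / fp_size)


print(f"{expected_int8_reduction(2048, 8192):.2f}%")
print(f"{expected_int8_reduction(2048, 8192, fp_bytes=2):.2f}%")
```

The per-channel scales add only a small overhead, so the estimate stays just under 75% for a
float32 baseline; if the full-precision checkpoint were stored in a 16-bit format instead,
the same arithmetic with `fp_bytes=2` gives just under 50%.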
"""

from keras_hub.models import Gemma3CausalLM

# Load from the Gemma3 preset.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt.
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save the FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again.
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save the INT8 Gemma3 model.
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs.
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
  after quantizing it to INT8.
* Always materialize weights before quantization (e.g., by calling `build()` or running a
  forward pass).
* Expect small numerical deltas; quantify them with a metric like MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend/device kernels.
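
MSE depends on the scale of your outputs, so it can help to report a scale-free number next
to it. The helper below is hypothetical (not a Keras API): it returns the MSE used earlier in
this guide together with the error norm relative to the reference norm. It works on NumPy
arrays, so convert backend tensors first, for example with `keras.ops.convert_to_numpy`.

```python
import numpy as np


def compare_outputs(y_ref, y_test):
    # Mean squared error, as used earlier in this guide.
    y_ref = np.asarray(y_ref, dtype=np.float64)
    y_test = np.asarray(y_test, dtype=np.float64)
    mse = np.mean(np.square(y_ref - y_test))
    # Scale-free companion metric: error norm relative to the reference norm.
    rel = np.linalg.norm(y_ref - y_test) / max(np.linalg.norm(y_ref), 1e-12)
    return mse, rel


y_ref = np.array([[1.0], [2.0], [3.0]])
y_test = np.array([[1.0], [2.1], [2.9]])
mse, rel = compare_outputs(y_ref, y_test)
print(f"MSE: {mse:.6f}, relative error: {rel:.4f}")
```

For the functional model above, passing `keras.ops.convert_to_numpy(y_fp32)` and
`keras.ops.convert_to_numpy(y_int8)` would report both numbers for the same evaluation batch.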

## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""