Path: blob/master/guides/int8_quantization_in_keras.py
"""
Title: 8-bit Integer Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
## What is INT8 quantization?

Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
`float16` halves the memory requirements; moving to INT8 cuts them by ~4x (and by ~2x relative
to `float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can
also improve throughput and latency. Actual gains depend on your backend and device.

### How it works

Quantization maps real values to 8-bit integers with a scale:

* Integer domain: `[-128, 127]` (256 levels).
* For a tensor (often per-output-channel for weights) with values `w`:
    * Compute `a_max = max(abs(w))`.
    * Set the scale `s = (2 * a_max) / 256`.
    * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
* Inference uses `q` and `s` to reconstruct effective weights on the fly
(`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
"""
"""
### Benefits

* Memory / bandwidth bound models: When a model spends most of its time on memory I/O,
speeding up the arithmetic alone does little for overall runtime. INT8 reduces bytes moved
by ~4x vs `float32`, improving cache behavior and reducing memory stalls; this often helps
more than increasing raw FLOPs.
* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8
[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
drop (often ~1-5%, depending on task/model/data). Always validate on your own dataset.

### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
is rewritten so you can run or save immediately.
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize it to INT8 in place,
and then compare outputs with an MSE metric.
"""

import os
import numpy as np
import keras
from keras import layers


# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (full-precision) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP32 and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))

"""
The INT8 quantized model produces outputs close to the original FP32 model, as indicated by
the low MSE value.
"""
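
"""
If you want to see what the in-place rewrite did, one option is to list the model's variables
with their dtypes. The snippet below is an illustrative sketch that assumes a recent Keras 3
release (where variables expose `path`, `dtype`, and `shape`); exact variable names and
layouts can vary across versions and backends, but you should see the `Dense` kernels stored
as INT8 tensors alongside small floating-point scale variables.
"""

# Print each variable's path, dtype, and shape to spot the INT8 kernels and their scales.
for variable in model.variables:
    print(f"{variable.path:<30} {str(variable.dtype):<8} {tuple(variable.shape)}")
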
"""
## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to `.keras` and loading back.
"""

# Save the quantized model and reload it to verify the round trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization and
follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute the storage savings.
4. Reload the INT8 model and verify output consistency with the original quantized model.
"""

from keras_hub.models import Gemma3CausalLM

# Load the Gemma3 preset.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt.
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save the FP32 Gemma3 model for the size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again.
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save the INT8 Gemma3 model.
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs.
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute the storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
after quantizing it to INT8.
* Always materialize weights before quantization (e.g., via `build()` or a forward pass).
* Expect small numerical deltas; quantify them with a metric like MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend/device kernels.

## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""