Path: blob/master/guides/int4_quantization_in_keras.py
"""1Title: INT4 Quantization in Keras2Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)3Date created: 2025/10/144Last modified: 2025/10/145Description: Complete guide to using INT4 quantization in Keras and KerasHub.6Accelerator: GPU7"""89"""10## What is INT4 quantization?1112Quantization lowers the numerical precision of weights and activations to reduce memory use13and often speed up inference, at the cost of a small accuracy drop. INT4 post-training14quantization (PTQ) stores model weights in 4-bit signed integers and dynamically quantizes15activations to 8-bit at runtime (a W4A8 scheme). Compared with FP32 this can shrink weight16storage ~8x (2x vs INT8) while retaining acceptable accuracy for many encoder models and17some decoder models. Compute still leverages widely available NVIDIA INT8 Tensor Cores.18194-bit is a more aggressive compression than 8-bit and may induce larger quality regressions,20especially for large autoregressive language models.2122## How it works2324Quantization maps real values to 4-bit integers with a scale:25261. Per-output-channel scale computed for each weight matrix (symmetric abs-max).272. Weights are quantized to values in `[-8, 7]` (4 bits) and packed two-per-byte.283. At inference, activations are dynamically scaled and quantized to INT8.294. Packed INT4 weights are unpacked to an INT8 tensor (still with INT4-range values).305. INT8 x INT8 matmul accumulates in INT32.316. Result is dequantized using `(input_scale * per_channel_kernel_scale)`.3233This mirrors the INT8 path described in the34[INT8 guide](https://keras.io/guides/int8_quantization_in_keras) with some added unpack35overhead for stronger compression.3637## Benefits38* Memory / bandwidth bound models: When the implementation spends most of its time on memory I/O,39reducing the computation time does not reduce its overall runtime. INT4 reduces bytes40moved by ~8x vs FP32, improving cache behavior and reducing memory stalls;41this often helps more than increasing raw FLOPs.42* Accuracy: Many architectures retain acceptable accuracy with INT4; encoder-only models43often fare better than decoder LLMs. Always validate on your own dataset.44* Compute bound layers on supported hardware: 4-bit kernels are unpacked to INT8 at inference,45therefore, on NVIDIA GPUs, INT8 [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)46speed up matmul/conv, boosting throughput on compute-limited layers.4748### What Keras does in INT4 mode4950* **Mapping**: Symmetric, linear quantization with INT4 plus a floating-point scale.51* **Weights**: per-output-channel scales to preserve accuracy.52* **Activations**: **dynamic AbsMax** scaling computed at runtime.53* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph54is rewritten so you can run or save immediately.55"""5657"""58## Overview5960This guide shows how to use 4-bit (W4A8) post-training quantization in Keras:61621. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)632. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)643. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)654. [When to use INT4 vs INT8](#when-should-i-use-int4-vs-int8)665. [Performance benchmarks](#performance--benchmarking)676. [Practical Tips](#practical-tips)687. 
"""
## Quantizing a Minimal Functional Model

Below we build a small functional model, capture a baseline output, quantize to INT4
in place, and compare outputs with an MSE metric. (For real evaluation, use your
validation metric.)
"""

import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)

# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")

# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))

"""
The INT4 quantized model usually produces outputs close enough for many downstream
tasks. Expect larger deltas than INT8, so always validate on your own data.
"""

"""
## Saving and Reloading a Quantized Model

You can use standard Keras saving / loading APIs. Quantization metadata (including
scales and packed weights) is preserved.
"""

# Save the quantized model and reload to verify the round-trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)

# Compare outputs (MSE).
roundtrip_mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4_reloaded))
print("MSE (INT4 vs reloaded INT4):", float(roundtrip_mse))
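"""
You can also check which layers were rewritten by inspecting their dtype policies. The exact
policy names are an implementation detail and may vary across Keras versions, so treat this
as a quick sanity check rather than a stable API.
"""

# Quantized layers typically report a quantized dtype policy; untouched layers keep float32.
for layer in int4_reloaded.layers:
    print(layer.name, "->", layer.dtype_policy)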
"""
## Quantizing a KerasHub Model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute storage savings.
4. Reload the INT4 model and verify output consistency with the original quantized model.
"""

import os
from keras_hub.models import Gemma3CausalLM

# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)

# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")

# Quantize to INT4.
gemma3.quantize("int4")

# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)

# Save the INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")

# Reload and compare outputs.
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")

output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
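"""
The tables below come from dedicated Colab benchmarks. For a rough sanity check on your own
hardware, you can time generation directly; the snippet below is a minimal sketch (single
prompt, wall-clock timing), not the harness used to produce the numbers that follow.
"""

import time

# Warm up once so one-time compilation and cache effects don't skew the measurement.
gemma3_int4.generate("Keras is a", max_length=30)

num_runs = 5
start = time.perf_counter()
for _ in range(num_runs):
    gemma3_int4.generate("Keras is a", max_length=30)
avg_latency_ms = (time.perf_counter() - start) / num_runs * 1000.0
print(f"INT4 average generation latency: {avg_latency_ms:.1f} ms")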
"""
## Performance & Benchmarking

Micro-benchmarks collected on a single NVIDIA L4 (22.5 GB). Baselines are FP32.

### Text Classification (DistilBERT Base on SST-2)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/77e874187d6da3f8280c053192f78c06/int4-quantization-micro-benchmark-distilbert.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Accuracy (↑) | 91.06% | 90.14% | -0.92pp |
| Model Size (MB, ↓) | 255.86 | 159.49 | -37.67% |
| Peak GPU Memory (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
| Latency (ms/sample, ↓) | 6.43 | 5.73 | -10.83% |
| Throughput (samples/s, ↑) | 155.60 | 174.50 | +12.15% |

**Analysis**: Accuracy drop is modest (<1pp) with notable speed and memory gains;
encoder-only models tend to retain fidelity under heavier weight compression.

### Text Generation (Falcon 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/19ab238e0f5b29ae24c0faf4128e7d7e/int4_quantization_micro_benchmark_falcon.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 7.44 | 9.98 | +34.15% |
| Model Size (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
| Peak GPU Memory (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
| First Token Latency (ms, ↓) | 128.87 | 122.50 | -4.95% |
| Sequence Latency (ms, ↓) | 338.29 | 181.93 | -46.22% |
| Token Throughput (tokens/s, ↑) | 174.41 | 256.96 | +47.33% |

**Analysis**: INT4 gives large size (-80%) and memory (-32%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Gemma3 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/9ca7813971868d5d1a16cd7998d0e352/int4_quantization_micro_benchmark_gemma3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.17 | 10.46 | +69.61% |
| Model Size (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
| Peak GPU Memory (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
| First Token Latency (ms, ↓) | 57.42 | 64.21 | +11.83% |
| Sequence Latency (ms, ↓) | 239.78 | 161.18 | -32.78% |
| Token Throughput (tokens/s, ↑) | 246.06 | 366.05 | +48.76% |

**Analysis**: INT4 gives large size (-61%) and memory (-27%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Llama 3.2 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/310f50a0ca0eba3754de41c612b3b8ef/int4_quantization_micro_benchmark_llama3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.38 | 14.16 | +121.78% |
| Model Size (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
| Peak GPU Memory (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
| First Token Latency (ms, ↓) | 209.41 | 219.09 | +4.62% |
| Sequence Latency (ms, ↓) | 322.33 | 262.15 | -18.67% |
| Token Throughput (tokens/s, ↑) | 183.82 | 230.78 | +25.55% |

**Analysis**: INT4 gives large size (-57%) and memory (-28%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~25%.

### Text Generation (OPT 125M)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/918fcdb8a1433dea12800f8ca4a240f5/int4_quantization_micro_benchmark_opt.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 13.85 | 21.02 | +51.79% |
| Model Size (MB, ↓) | 468.3 | 284.0 | -39.37% |
| Peak GPU Memory (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
| First Token Latency (ms/sample, ↓) | 95.79 | 97.87 | +2.18% |
| Sequence Latency (ms/sample, ↓) | 60.35 | 54.64 | -9.46% |
| Throughput (samples/s, ↑) | 973.41 | 1075.15 | +10.45% |

**Analysis**: INT4 gives large size (-39%) and memory (-35%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~10%.
"""

"""
## When should I use INT4 vs INT8?

| Goal / Constraint | Prefer INT8 | Prefer INT4 (W4A8) |
| ----------------- | ----------- | ------------------ |
| Minimal accuracy drop critical | ✔︎ | |
| Maximum compression (disk / RAM) | | ✔︎ |
| Bandwidth-bound inference | Possible | Often better |
| Decoder LLM | ✔︎ | Try with eval first |
| Encoder / classification models | ✔︎ | ✔︎ |
| Available kernels / tooling maturity | ✔︎ | Emerging |

* Start with INT8; if memory or distribution size is still a bottleneck, evaluate INT4.
* For LLMs, measure task-specific metrics (perplexity, exact match, etc.) after INT4.
* Combine INT4 weights + LoRA adapters for efficient fine-tuning workflows (see the sketch
after this section).
"""
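"""
To illustrate the last point: the packed INT4 kernels themselves are frozen, but you can
attach low-rank LoRA adapters to the quantized layers and train only those. The sketch below
reuses the small functional model from earlier. Whether `enable_lora` can be called on an
already-quantized layer depends on your Keras version, so treat this as a pattern to adapt
rather than a guaranteed API (hence the try/except).
"""

# Attach LoRA adapters (rank 2) to the INT4-quantized Dense layers.
for layer in model.layers:
    if isinstance(layer, layers.Dense):
        try:
            layer.enable_lora(rank=2)
        except (ValueError, NotImplementedError) as e:
            print(f"Could not enable LoRA on {layer.name}: {e}")

# The float LoRA adapters (and any remaining float variables) are trainable;
# the packed INT4 kernels stay frozen.
print("Trainable variables after enabling LoRA:")
for variable in model.trainable_variables:
    print(" ", variable.path, variable.shape)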
"""
## Practical Tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
after quantizing it to INT4.
* Always materialize weights before quantization (e.g., `build()` or a forward pass).
* Evaluate on a representative validation set; track task metrics, not just MSE.
* Use LoRA for further fine-tuning.

## Limitations

* Runtime unpacking adds overhead (weights are decompressed layer-wise on each forward pass).
* Aggressive compression can cause larger accuracy drops, especially for decoder-only LLMs.
* The LoRA export path is lossy (dequantize -> add delta -> requantize).
* Keras does not yet ship native fused INT4 kernels; inference relies on unpack + INT8 matmul.
"""
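"""
Finally, to make the "materialize weights before quantization" tip concrete: a freshly
constructed model has no kernels to rewrite until it is built, so call `build()` (or run a
forward pass) before `quantize()`. A minimal sketch:
"""

# A model created from scratch has unbuilt weights until it sees an input shape.
fresh = keras.Sequential([layers.Dense(8, activation="relu"), layers.Dense(1)])

# Materialize the weights, either via build() or by calling the model on data.
fresh.build(input_shape=(None, 10))

# Now the kernels exist and can be rewritten in place to INT4.
fresh.quantize("int4")
print("Quantized layers:", [layer.name for layer in fresh.layers])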