Path: blob/master/guides/gptq_quantization_in_keras.py
"""1Title: GPTQ Quantization in Keras2Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)3Date created: 2025/10/164Last modified: 2025/10/165Description: How to run weight-only GPTQ quantization for Keras & KerasHub models.6Accelerator: GPU7"""89"""10## What is GPTQ?1112GPTQ ("Generative Pre-Training Quantization") is a post-training, weight-only13quantization method that uses a second-order approximation of the loss (via a14Hessian estimate) to minimize the error introduced when compressing weights to15lower precision, typically 4-bit integers.1617Unlike standard post-training techniques, GPTQ keeps activations in18higher-precision and only quantizes the weights. This often preserves model19quality in low bit-width settings while still providing large storage and20memory savings.2122Keras supports GPTQ quantization for KerasHub models via the23`keras.quantizers.GPTQConfig` class.24"""2526"""27## Load a KerasHub model2829This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B30parameter) causal language model.3132"""33import keras34from keras_hub.models import Gemma3CausalLM35from datasets import load_dataset363738prompt = "Keras is a"3940model = Gemma3CausalLM.from_preset("gemma3_1b")4142outputs = model.generate(prompt, max_length=30)43print(outputs)4445"""46## Configure & run GPTQ quantization4748You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.4950The GPTQ configuration requires a calibration dataset and tokenizer, which it51uses to estimate the Hessian and quantization error. Here, we use a small slice52of the WikiText-2 dataset for calibration.5354You can tune several parameters to trade off speed, memory, and accuracy. The55most important of these are `weight_bits` (the bit-width to quantize weights to)56and `group_size` (the number of weights to quantize together). The group size57controls the granularity of quantization: smaller groups typically yield better58accuracy but are slower to quantize and may use more memory. A good starting59point is `group_size=128` for 4-bit quantization (`weight_bits=4`).6061In this example, we first prepare a tiny calibration set, and then run GPTQ on62the model using the `.quantize(...)` API.63"""6465# Calibration slice (use a larger/representative set in practice)66texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]6768calibration_dataset = [69s + "." 
"""
In this example, we first prepare a tiny calibration set, and then run GPTQ on
the model using the `.quantize(...)` API.
"""

# Calibration slice (use a larger/representative set in practice)
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

calibration_dataset = [
    s + "."
    for text in texts
    for s in map(str.strip, text.split("."))
    if s
]

gptq_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=128,
    num_samples=256,
    sequence_length=256,
    hessian_damping=0.01,
    symmetric=False,
    activation_order=False,
)

model.quantize("gptq", config=gptq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Model Export

The GPTQ-quantized model can be saved to a preset and reloaded elsewhere, just
like any other KerasHub model.
"""

model.save_to_preset("gemma3_gptq_w4gs128_preset")
model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
output = model_from_preset.generate(prompt, max_length=30)
print(output)

"""
## Performance & Benchmarking

Micro-benchmarks collected on a single NVIDIA RTX 4070 Ti Super (16 GB).
Baselines are FP32.

Dataset: WikiText-2.

| Model (preset)                    | Perplexity Increase % (↓ better) | Disk Storage Δ % (↓ better) | VRAM Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |
| --------------------------------- | -------------------------------: | --------------------------: | ------------------: | ---------------------------------: | ------------------------: |
| GPT2 (gpt2_base_en_cnn_dailymail) |                             1.0% |                    -50.1% ↓ |            -41.1% ↓ |                            +0.7% ↑ |                  +20.1% ↑ |
| OPT (opt_125m_en)                 |                            10.0% |                    -49.8% ↓ |            -47.0% ↓ |                            +6.7% ↑ |                  -15.7% ↓ |
| Bloom (bloom_1.1b_multi)          |                             7.0% |                    -47.0% ↓ |            -54.0% ↓ |                            +1.8% ↑ |                  -15.7% ↓ |
| Gemma3 (gemma3_1b)                |                             3.0% |                    -51.5% ↓ |            -51.8% ↓ |                           +39.5% ↑ |                   +5.7% ↑ |

Detailed benchmarking numbers and scripts are available
[here](https://github.com/keras-team/keras/pull/21641).

### Analysis

There is a notable reduction in disk and VRAM usage across all models, with
disk savings of around 50% and VRAM savings ranging from 41% to 54%. The
reported disk savings understate the true weight compression, because presets
also include non-weight assets. A small helper after the practical tips below
shows how to check the on-disk footprint of your own quantized preset.

Perplexity increases only modestly (1% to 10% across the benchmarked models),
indicating that model quality is largely preserved after quantization.
"""

"""
## Practical tips

* GPTQ is a post-training technique; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only for demos.
* Start with 4-bit weights (`weight_bits=4`) and `group_size=128`; tune per model/task.
"""
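"""
As a quick sanity check on storage savings, you can inspect the on-disk size of
the quantized preset saved earlier in this guide. The helper below is an
illustrative sketch, not a Keras API: `preset_size_mb` simply walks the preset
directory and sums file sizes. To reproduce the disk-reduction numbers above,
you would save the FP32 preset with `save_to_preset(...)` before quantizing and
compare the two directories.
"""

import os


def preset_size_mb(preset_dir):
    """Total size of all files under `preset_dir`, in megabytes."""
    total_bytes = 0
    for root, _, files in os.walk(preset_dir):
        for name in files:
            total_bytes += os.path.getsize(os.path.join(root, name))
    return total_bytes / (1024**2)


print(f"Quantized preset size: {preset_size_mb('gemma3_gptq_w4gs128_preset'):.1f} MB")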