"""1Title: Writing Quantization-Compatible Layers in Keras2Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)3Date created: 2025/10/164Last modified: 2025/10/165Description: Complete guide for writing quantization-compatible Keras layers.6Accelerator: GPU7"""89"""10## What are quantization-compatible layers?1112Keras lets you optimize models via post-training quantization (PTQ) by calling13the `layer.quantize(...)` or `model.quantize(...)` APIs. Keras exposes an14extensible framework for defining quantization-compatible layers. This lets you15author custom layers that plug into the quantization framework, can be quantized16to INT8 or INT4, and saved/loaded with quantization metadata.1718A quantization-compatible layer needs to implement a few hooks, so that it can:1920- Switch its variables to quantized representations.21- Use a quantization-aware forward path at inference.22- Save and load quantization metadata with the model.2324In this guide, we'll implement a simple layer that supports INT8 PTQ. The same25patterns generalize to INT4 quantization and FP8 mixed-precision training.26"""2728"""29## The hooks you'll implement3031At minimum, your layer should define:3233- `quantize(mode, **kwargs)`: Converts existing variables to quantized form and34switches the dtype policy35- `_int8_build(...)`: Allocates INT8 variables needed by your layer36- `_int8_call(inputs, training=None)`: Minimal INT8 forward path3738We'll implement these for a very small layer called `SimpleScale`, which39multiplies the inputs by a trainable per-feature vector (elementwise scaling on40the last dimension). The same patterns generalize to more sophisticated layers.41"""4243"""44## Writing a Simple Quantization-Compatible Layer4546We start with a tiny layer that learns a per-feature multiplier. The47full-precision path just computes `y = x * w`. We'll add the quantization hooks48step by step.49"""5051import numpy as np52import keras53from keras import ops, quantizers, dtype_policies54from keras.layers import Layer, Input555657class SimpleScale(Layer):58"""A layer that learns a per-feature scaling factor."""5960def __init__(self, **kwargs):61super().__init__(**kwargs)6263def build(self, input_shape):64input_dim = input_shape[-1]65self._kernel = self.add_weight(66name="kernel",67shape=(input_dim,),68initializer="random_uniform",69)7071def call(self, inputs, training=None):72return ops.multiply(inputs, self._kernel)737475"""76### The `quantize()` method7778PTQ is a one-time rewrite. After you train or load your FP32 layer, you call79`layer.quantize("int8")`. The layer should:80811. Read its existing full-precision variables (e.g., `self._kernel`).822. Quantize them to INT8 values plus a quantization scale.833. Replace full-precision variables with INT8 storage and assign the quantized84data.854. Switch the `dtype_policy` to a quantized variant (e.g., `int8_from_float32`).86"""878889def quantize(self, mode, **kwargs):90if mode != "int8":91raise NotImplementedError(f"Unsupported quantization mode: {mode}")9293quantized_kernel, scale = quantizers.abs_max_quantize(94self._kernel, axis=0, dtype="int8", to_numpy=True95)96scale = ops.squeeze(scale, axis=0)9798kernel_shape = self._kernel.shape99100del self._kernel101102# Allocate INT8 variables. 
"""
### The `_int8_build(...)` method

The `_int8_build(...)` method is called from `quantize(...)` to initialize the
INT8 variables. It should allocate:

- `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as
the original full-precision kernel).
- `self.scale` as the scalar quantization scale in the layer's variable dtype,
which is FP32 in this case.
"""


def _int8_build(self, kernel_shape):
    (input_dim,) = kernel_shape
    self._kernel = self.add_weight(
        name="kernel",
        shape=(input_dim,),
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        initializer="ones",
        trainable=False,
    )


"""
#### Note

1. INT8 variables should be created with `trainable=False`, as quantized parameters
are not meant to be updated during training. Subsequent fine-tuning should not
alter these quantized variables.
2. If you support INT4 quantization, implement a similar `_int4_build(...)`
method that allocates packed INT4 storage (often packed into INT8) plus
per-feature scales. The original unpacked dimensions and packing axis should
be recorded as instance variables for use in the call path. A reference
implementation is available in the Keras
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)
layer.
"""

"""
### The `_int8_call(...)` method

The `_int8_call(...)` method implements a minimal INT8 forward path. It uses the
quantized variables allocated in `_int8_build(...)` and de-scales the output
back to floating-point.

The base `keras.Layer` class automatically dispatches to this method when the
layer is quantized, without requiring you to wire it up manually.

The INT8 path mirrors the float computation `y = x * w` but performs:

1. Elementwise multiply using the quantized weight.
2. De-scale back to float by dividing by the `scale`.
"""


def _int8_call(self, inputs, training=None):
    x = ops.multiply(inputs, self._kernel)
    x = ops.divide(x, self.scale)
    return x


"""
## Complete `SimpleScale` class with hooks

Below is the full class definition that incorporates all the hooks shown above
(`quantize`, `_int8_build`, `_int8_call`).
"""


class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x


"""
## Try it: quantize and run a forward pass

Below we build the layer, then quantize to INT8 and call it again.
"""

# Sample inputs
rng = np.random.default_rng()
x = rng.random((2, 4)).astype("float32")

layer = SimpleScale()

# Forward pass in float
y_fp = layer(x)

# Quantize to INT8 and run again
layer.quantize("int8")
y_int8 = layer(x)

print("SimpleScale FP32 sample:", y_fp[0].numpy())
print("SimpleScale INT8 sample:", y_int8[0].numpy())
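"""
As a quick smoke test (a minimal sketch; the tolerance below is an assumption,
not something prescribed by the quantization framework), you can check that the
de-scaled INT8 output stays close to the FP32 output:
"""

# Maximum elementwise deviation introduced by weight quantization.
max_abs_err = ops.max(ops.abs(ops.subtract(y_fp, y_int8)))
print("Max absolute FP32 vs INT8 difference:", float(ops.convert_to_numpy(max_abs_err)))
# Assert the two paths agree within a loose absolute tolerance.
np.testing.assert_allclose(
    ops.convert_to_numpy(y_fp), ops.convert_to_numpy(y_int8), atol=1e-2
)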
"""
## Extending to INT4

If you want to support INT4 quantization, add:

- `_int4_build(...)`: allocate packed 4-bit storage (often packed into int8) plus per-feature scales
- `_int4_call(...)`: unpack at runtime and follow the same de-scale pattern
- `quantize("int4")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage

See the
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)
reference for a complete packed int4 example, including how to track and use the
original (unpacked) dimension in the call path.
"""
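"""
Below is a minimal sketch of what these hooks could look like for `SimpleScale`.
It follows the `Dense` reference and assumes `quantizers.pack_int4` returns the
packed tensor together with the packed shape and the original length along the
pack axis; double-check this signature against your Keras version before
relying on it.
"""


def _int4_build(self, kernel_shape):
    (input_dim,) = kernel_shape
    # Two 4-bit values are stored per int8 byte along the pack axis.
    packed_dim = (input_dim + 1) // 2
    # Record unpacking metadata for the call path.
    self._int4_pack_axis = 0
    self._original_input_dim = input_dim
    self._kernel = self.add_weight(
        name="kernel",
        shape=(packed_dim,),
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        initializer="ones",
        trainable=False,
    )


def _int4_call(self, inputs, training=None):
    # Unpack to int8 values in [-8, 7], compute, then de-scale back to float.
    unpacked_kernel = quantizers.unpack_int4(
        self._kernel, self._original_input_dim, axis=self._int4_pack_axis
    )
    x = ops.multiply(inputs, unpacked_kernel)
    x = ops.divide(x, self.scale)
    return x


def quantize(self, mode, **kwargs):
    # Sketch: only the "int4" branch is shown; a real layer would keep the
    # "int8" branch from the previous sections as well.
    if mode != "int4":
        raise NotImplementedError(f"Unsupported quantization mode: {mode}")

    quantized_kernel, scale = quantizers.abs_max_quantize(
        self._kernel, axis=0, value_range=(-8, 7), dtype="int8", to_numpy=True
    )
    scale = ops.squeeze(scale, axis=0)
    kernel_shape = self._kernel.shape
    del self._kernel

    self._int4_build(kernel_shape)

    # Assumed return signature: (packed_tensor, packed_shape, original_length).
    packed_kernel, _, _ = quantizers.pack_int4(quantized_kernel, axis=0)
    self._kernel.assign(packed_kernel)
    self.scale.assign(scale)

    self._is_quantized = True
    if self.dtype_policy.quantization_mode is None:
        policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
        self.dtype_policy = policy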
"""
## Adding Serialization Support

Keras depends on a fixed serialization contract for saving and loading models.
This contract is complicated by quantization, since the variables you need to
save and load depend on the quantization mode.

The framework provides two hooks for layers to customize variable serialization:

- `save_own_variables(self, store)`: Write variables to `store` in a fixed
order.
- `load_own_variables(self, store)`: Read variables from `store` in the same
order.

Additionally, the `build(...)` method should be modified to allocate the
correct variables based on the presence (or absence) of `self.quantization_mode`.

For this layer we only aim to support two modes (non-quantized and INT8), so the
serialization contract is:

- None (no quantization): `["kernel"]`
- INT8: `["kernel", "scale"]`

The following code implements the required hooks; Keras will call them during
`model.save(...)` and `keras.saving.load_model(...)`.
"""


def save_own_variables(self, store):
    # Write variables to `store` in a fixed order based on quantization mode.
    # `store` is a key-value mapping provided by Keras during model.save().
    # Values are tensors.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        # Order: _kernel
        store[str(idx)] = self._kernel
    elif mode == "int8":
        # Order: _kernel, scale
        store[str(idx)] = self._kernel
        idx += 1
        store[str(idx)] = self.scale
    else:
        raise ValueError(f"Unsupported quantization mode for save: {mode}")


def load_own_variables(self, store):
    # Read variables from `store` in the same order used by
    # `save_own_variables`. Keras calls this during
    # `keras.saving.load_model(...)`.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        self._kernel.assign(store[str(idx)])
    elif mode == "int8":
        self._kernel.assign(store[str(idx)])
        idx += 1
        self.scale.assign(store[str(idx)])
    else:
        raise ValueError(f"Unsupported quantization mode for load: {mode}")


"""
### Modify the `build(...)` method

The build method itself also needs to be aware of the quantization mode. If a
saved quantized layer is being loaded/deserialized, `self.quantization_mode`
will be set before `build(...)` is called. In that case, we need to allocate
quantized variables directly instead of full-precision ones.
"""


def build(self, input_shape):
    input_dim = input_shape[-1]

    # Quantized build path.
    if self.quantization_mode:
        if self.quantization_mode == "int8":
            self._int8_build((input_dim,))
    else:
        # Regular FP32 build path.
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )


"""
## Complete implementation with serialization

The full class with serialization support looks like this:
"""


@keras.saving.register_keras_serializable()
class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]

        if self.quantization_mode:
            if self.quantization_mode == "int8":
                self._int8_build((input_dim,))
        else:
            self._kernel = self.add_weight(
                name="kernel",
                shape=(input_dim,),
                initializer="random_uniform",
            )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x

    def save_own_variables(self, store):
        # Write variables to `store` in a fixed order based on quantization mode.
        # `store` is a key-value mapping provided by Keras during model.save();
        # values are tensors.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            # Order: _kernel
            store[str(idx)] = self._kernel
        elif mode == "int8":
            # Order: _kernel, scale
            store[str(idx)] = self._kernel
            idx += 1
            store[str(idx)] = self.scale
        else:
            raise ValueError(f"Unsupported quantization mode for save: {mode}")

    def load_own_variables(self, store):
        # Read variables from `store` in the same order used by
        # `save_own_variables`. Keras calls this during
        # `keras.saving.load_model(...)`.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            self._kernel.assign(store[str(idx)])
        elif mode == "int8":
            self._kernel.assign(store[str(idx)])
            idx += 1
            self.scale.assign(store[str(idx)])
        else:
            raise ValueError(f"Unsupported quantization mode for load: {mode}")


"""
#### Note

The `@keras.saving.register_keras_serializable()` decorator is needed to
register the class for serialization.
"""
"""
## Try it: quantize, save, and load
"""
model = keras.Sequential([Input(shape=(4,)), SimpleScale()])
model.build((None, 4))

# Quantize to INT8.
model.quantize("int8")
y_int8 = model(x)
print("SimpleScale INT8 sample:", y_int8[0].numpy())

# Save and load the quantized model.
model.save("simplescale_int8.keras")
loaded = keras.saving.load_model("simplescale_int8.keras")

y_loaded = loaded(x)
print("Loaded INT8 sample:", y_loaded[0].numpy())
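"""
As an optional check (a small sketch; since the INT8 kernel and scale are
restored verbatim, the round-tripped outputs should match to within numerical
tolerance), verify that the loaded model reproduces the in-memory quantized
outputs:
"""

np.testing.assert_allclose(
    ops.convert_to_numpy(y_loaded), ops.convert_to_numpy(y_int8)
)
print("Loaded model output matches the quantized model output.")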
"""
## Practical tips

Here are concrete patterns you can reuse when making your own layers PTQ-friendly.

- Build-time vs call-time responsibilities
    - In `build(...)`, if `self.quantization_mode` is set: allocate the quantized
    variables and skip allocating the float kernel to avoid duplicates.
    - Record any metadata you need for the call path, e.g., for INT4:
        - The axis you packed along (e.g., `_int4_pack_axis`).
        - The original (unpacked) length on that axis (e.g., `_original_input_dim`
        or `_original_length_along_pack_axis`).
    - In quantized call hooks, compute with the quantized buffers and de-scale back
    to float at the end, wherever possible. This allows you to leverage optimized
    low-precision kernels (e.g., cuBLAS INT8 GEMM).

- INT4 specifics (packed nibbles)
    - Quantize to INT4 values in range [-8, 7] (still dtype int8), then pack two
    4-bit integers per byte with `quantizers.pack_int4(..., axis=pack_axis)`.
    - Store the packed kernel with `dtype="int8"`. Unpack on the fly in the call
    path with `quantizers.unpack_int4(packed, orig_len, axis=pack_axis)`.
    - Keep the original length and pack axis so you can unpack for LoRA,
    gradients, and serialization.

- Input quantization and broadcasting
    - In the forward path, de-scale outputs using
    `outputs /= (inputs_scale * kernel_scale)`; make sure both scales broadcast to
    the output shape (see the sketch after this list).

- Dtype policy lifecycle
    - During `quantize(mode)`: delete FP32 variables, allocate quantized ones,
    assign values, then set `self._is_quantized = True` before changing the
    dtype policy.
    - Only change the policy if the current policy has `quantization_mode is None`
    to avoid an infinite loop.

- Serialization contract
    - Provide fixed-order logic for variable serialization so save/load is
    deterministic.
    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
    `"int8"`: [kernel, bias, kernel_scale], `"int4"`: [kernel, bias, kernel_scale]).

- Validation and error handling
    - Validate `mode` early and raise a `NotImplementedError` for unsupported
    modes.
    - After quantization, run a tiny smoke test and assert the output matches the
    FP32 path within a reasonable tolerance after de-scale.

- Performance hygiene
    - Avoid repeated transformations in hot paths; precompute as much information
    as possible and keep the forward-pass hooks lightweight.
    - Keep quantized buffers `trainable=False` and prefer vectorized operations.
"""
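"""
The following is a minimal sketch of that input-quantization pattern for a
hypothetical Dense-like layer. The attribute names `self._kernel` and
`self.kernel_scale` are illustrative; casting to the compute dtype keeps the
example portable, whereas a production kernel would keep the operands in int8
to hit optimized INT8 GEMMs.
"""


def _int8_call_with_input_quantization(self, inputs, training=None):
    # Dynamically quantize the activations along the feature axis.
    inputs_q, inputs_scale = quantizers.abs_max_quantize(inputs, axis=-1, dtype="int8")
    # `inputs_scale` keeps the reduced axis as size 1, so it broadcasts below.
    x = ops.matmul(
        ops.cast(inputs_q, self.compute_dtype),
        ops.cast(self._kernel, self.compute_dtype),
    )
    # De-scale with both scales; their product broadcasts to the output shape.
    x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale))
    return x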
"""
For more advanced patterns, refer to the
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py) and
[EinsumDense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/einsum_dense.py)
reference implementations.
"""