Writing Quantization-Compatible Layers in Keras
Author: Jyotinder Singh
Date created: 2025/10/16
Last modified: 2025/10/16
Description: Complete guide for writing quantization-compatible Keras layers.
What are quantization-compatible layers?
Keras lets you optimize models via post-training quantization (PTQ) by calling the layer.quantize(...) or model.quantize(...) APIs. Keras exposes an extensible framework for defining quantization-compatible layers. This lets you author custom layers that plug into the quantization framework, can be quantized to INT8 or INT4, and saved/loaded with quantization metadata.
A quantization-compatible layer needs to implement a few hooks, so that it can:
Switch its variables to quantized representations.
Use a quantization-aware forward path at inference.
Save and load quantization metadata with the model.
In this guide, we'll implement a simple layer that supports INT8 PTQ. The same patterns generalize to INT4 quantization and FP8 mixed-precision training.
The hooks you'll implement
At minimum, your layer should define:
quantize(mode, **kwargs): Converts existing variables to quantized form and switches the dtype policy.
_int8_build(...): Allocates the INT8 variables needed by your layer.
_int8_call(inputs, training=None): Minimal INT8 forward path.
We'll implement these for a very small layer called SimpleScale, which multiplies the inputs by a trainable per-feature vector (elementwise scaling on the last dimension). The same patterns generalize to more sophisticated layers.
Writing a Simple Quantization-Compatible Layer
We start with a tiny layer that learns a per-feature multiplier. The full-precision path just computes y = x * w. We'll add the quantization hooks step by step.
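Here is a minimal sketch of the full-precision layer (the variable name and the ones initializer are illustrative choices):

```python
import keras
from keras import ops


class SimpleScale(keras.layers.Layer):
    """Multiplies inputs elementwise by a trainable per-feature vector."""

    def build(self, input_shape):
        # One multiplier per feature on the last axis.
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1],),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs):
        # Full-precision path: y = x * w
        return ops.multiply(inputs, self._kernel)
```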
The quantize() method
PTQ is a one-time rewrite. After you train or load your FP32 layer, you call layer.quantize("int8"). The layer should:
Read its existing full-precision variables (e.g., self._kernel).
Quantize them to INT8 values plus a quantization scale.
Replace the full-precision variables with INT8 storage and assign the quantized data.
Switch the dtype_policy to a quantized variant (e.g., int8_from_float32).
Note
The quantize(...) method should validate mode and raise a NotImplementedError if the mode is not supported.
Ensure your quantize(...) sets a quantized dtype policy based on the prior policy, e.g., int8_from_float32 or int8_from_bfloat16. This ensures that the layer's quantization_mode is correctly set.
The _is_quantized flag should be set before changing the dtype policy, to inform the setter that the quantized variables are already initialized.
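As a sketch (assuming the abs-max quantizer exposed as keras.quantizers.abs_max_quantize and a scalar scale variable named scale), quantize(...) for SimpleScale might look like this:

```python
# Method of SimpleScale; assumes: from keras import dtype_policies, ops, quantizers
def quantize(self, mode, **kwargs):
    if mode != "int8":
        raise NotImplementedError(f"Unsupported quantization mode: {mode}")

    # 1. Quantize the existing FP32 kernel to INT8 values plus a scalar scale.
    kernel_value, kernel_scale = quantizers.abs_max_quantize(self._kernel, axis=0)
    kernel_scale = ops.squeeze(kernel_scale, axis=0)

    # 2. Replace the FP32 variable with INT8 storage and assign the quantized data.
    input_dim = self._kernel.shape[-1]
    del self._kernel
    self._int8_build(input_dim)
    self._kernel.assign(kernel_value)
    self.scale.assign(kernel_scale)

    # 3. Mark the layer as quantized, then switch the dtype policy.
    self._is_quantized = True
    if self.dtype_policy.quantization_mode is None:
        self.dtype_policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
```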
The _int8_build(...) method
The _int8_build(...) method is called from quantize(...) to initialize the INT8 variables. It should allocate:
self._kernel as an INT8 vector of shape (input_dim,) (the same shape as the original full-precision kernel).
self.scale as the scalar quantization scale in the layer's variable dtype, which is FP32 in this case.
Note
INT8 variables should be created with trainable=False, as quantized parameters are not meant to be updated during training. Subsequent fine-tuning should not alter these quantized variables.
If you support INT4 quantization, implement a similar _int4_build(...) method that allocates packed INT4 storage (often packed into INT8) plus per-feature scales. The original unpacked dimensions and packing axis should be recorded as instance variables for use in the call path. A reference implementation is available in the Keras Dense layer.
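A possible _int8_build(...) for SimpleScale, following the shapes described above (the initializers are placeholders; the values are overwritten in quantize(...)):

```python
# Method of SimpleScale.
def _int8_build(self, input_dim):
    # INT8 storage for the kernel; quantized parameters are not trainable.
    self._kernel = self.add_weight(
        name="kernel",
        shape=(input_dim,),
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    # Scalar de-scaling factor, kept in the layer's variable dtype (FP32 here).
    self.scale = self.add_weight(
        name="scale",
        shape=(),
        initializer="ones",
        trainable=False,
    )
```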
The _int8_call(...) method
The _int8_call(...) method implements a minimal INT8 forward path. It uses the quantized variables allocated in _int8_build(...) and de-scales the output back to floating-point.
The base keras.Layer class automatically dispatches to this method when the layer is quantized, without requiring you to wire it up manually.
The INT8 path mirrors the float computation y = x * w but performs:
Elementwise multiply using the quantized weight.
De-scale back to float by dividing by the scale.
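A matching _int8_call(...) sketch:

```python
# Method of SimpleScale; assumes: from keras import ops
def _int8_call(self, inputs, training=None):
    # Elementwise multiply with the INT8 kernel (cast for the multiply),
    # then de-scale back to float by dividing by the quantization scale.
    kernel = ops.cast(self._kernel, self.compute_dtype)
    outputs = ops.multiply(inputs, kernel)
    return ops.divide(outputs, self.scale)
```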
Complete SimpleScale class with hooks
Below is the full class definition that incorporates all the hooks shown above (quantize, _int8_build, _int8_call).
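One possible version, combining the sketches from the previous sections (still illustrative; it assumes keras.quantizers.abs_max_quantize for the weight quantization):

```python
import keras
from keras import dtype_policies, ops, quantizers


class SimpleScale(keras.layers.Layer):
    """Scales inputs elementwise by a trainable per-feature vector."""

    def build(self, input_shape):
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1],),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        # Quantize the FP32 kernel to INT8 values plus a scalar abs-max scale.
        kernel_value, kernel_scale = quantizers.abs_max_quantize(
            self._kernel, axis=0
        )
        kernel_scale = ops.squeeze(kernel_scale, axis=0)

        # Swap the FP32 variable for INT8 storage and assign the quantized data.
        input_dim = self._kernel.shape[-1]
        del self._kernel
        self._int8_build(input_dim)
        self._kernel.assign(kernel_value)
        self.scale.assign(kernel_scale)

        # Mark the layer as quantized before switching the dtype policy.
        self._is_quantized = True
        if self.dtype_policy.quantization_mode is None:
            self.dtype_policy = dtype_policies.get(
                f"{mode}_from_{self.dtype_policy.name}"
            )

    def _int8_build(self, input_dim):
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale", shape=(), initializer="ones", trainable=False
        )

    def _int8_call(self, inputs, training=None):
        kernel = ops.cast(self._kernel, self.compute_dtype)
        return ops.divide(ops.multiply(inputs, kernel), self.scale)
```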
Try it: quantize and run a forward pass
Below we build the layer, then quantize to INT8 and call it again.
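A minimal usage sketch (random inputs; the shapes are arbitrary):

```python
import numpy as np

import keras

x = np.random.rand(2, 4).astype("float32")

layer = SimpleScale()
y_float = layer(x)  # Builds the FP32 kernel and runs the float path.

layer.quantize("int8")
y_quantized = layer(x)  # Keras now dispatches to _int8_call automatically.

print(
    "Max absolute difference:",
    np.max(
        np.abs(
            keras.ops.convert_to_numpy(y_float)
            - keras.ops.convert_to_numpy(y_quantized)
        )
    ),
)
```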
Extending to INT4
If you want to support INT4 quantization, add:
_int4_build(...): allocate packed 4-bit storage (often packed into int8) plus per-feature scales.
_int4_call(...): unpack at runtime and follow the same de-scale pattern.
quantize("int4"): quantize weights with value_range=(-8, 7), then pack to int4 storage.
See the Dense reference for a complete packed int4 example, including how to track and use the original (unpacked) dimension in the call path.
Adding Serialization Support
Keras depends on a fixed serialization contract for saving and loading models. This contract is complicated by quantization, since the variables you need to save and load depend on the quantization mode.
The framework provides two hooks for layers to customize variable serialization:
save_own_variables(self, store): Write variables to store in a fixed order.
load_own_variables(self, store): Read variables from store in the same order.
Additionally, the build(...) method should be modified to allocate the correct variables based on the presence (or absence) of self.quantization_mode.
For this layer we only aim to support two modes (Non-quantized and INT8), so the serialization contract is:
None (no quantization): ["kernel"]
INT8: ["kernel", "scale"]
The following code implements the required hooks; Keras will call them during model.save(...) and keras.saving.load_model(...).
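A sketch of these hooks for SimpleScale, writing variables under numeric string keys in the order listed above (the key scheme is an illustrative choice):

```python
# Methods of SimpleScale.
def save_own_variables(self, store):
    # Contract: None -> [kernel]; "int8" -> [kernel, scale].
    targets = [self._kernel]
    if self.quantization_mode == "int8":
        targets.append(self.scale)
    for i, variable in enumerate(targets):
        store[str(i)] = variable


def load_own_variables(self, store):
    # Read back in the same fixed order used when saving.
    targets = [self._kernel]
    if self.quantization_mode == "int8":
        targets.append(self.scale)
    for i, variable in enumerate(targets):
        variable.assign(store[str(i)])
```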
Modify the build(...) method
The build method itself also needs to be aware of quantization mode. If a saved quantized layer is being loaded/deserialized, self.quantization_mode will be set before build(...) is called. In that case, we need to allocate quantized variables directly instead of full-precision ones.
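A build(...) sketch that branches on the quantization mode:

```python
# Method of SimpleScale.
def build(self, input_shape):
    input_dim = input_shape[-1]
    if self.quantization_mode == "int8":
        # Deserializing an already-quantized layer: allocate INT8 storage
        # directly and skip the float kernel to avoid duplicate variables.
        self._int8_build(input_dim)
    else:
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="ones",
            trainable=True,
        )
```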
Complete implementation with serialization
The full class with serialization support looks like this:
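One possible end-to-end version (a sketch; it follows the same assumptions as the examples above):

```python
import keras
from keras import dtype_policies, ops, quantizers


@keras.saving.register_keras_serializable()
class SimpleScale(keras.layers.Layer):
    """Scales inputs elementwise by a per-feature vector; supports INT8 PTQ."""

    def build(self, input_shape):
        input_dim = input_shape[-1]
        if self.quantization_mode == "int8":
            # Deserializing an already-quantized layer.
            self._int8_build(input_dim)
        else:
            self._kernel = self.add_weight(
                name="kernel",
                shape=(input_dim,),
                initializer="ones",
                trainable=True,
            )

    def call(self, inputs):
        return ops.multiply(inputs, self._kernel)

    # Quantization hooks.

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")
        kernel_value, kernel_scale = quantizers.abs_max_quantize(
            self._kernel, axis=0
        )
        kernel_scale = ops.squeeze(kernel_scale, axis=0)
        input_dim = self._kernel.shape[-1]
        del self._kernel
        self._int8_build(input_dim)
        self._kernel.assign(kernel_value)
        self.scale.assign(kernel_scale)
        self._is_quantized = True
        if self.dtype_policy.quantization_mode is None:
            self.dtype_policy = dtype_policies.get(
                f"{mode}_from_{self.dtype_policy.name}"
            )

    def _int8_build(self, input_dim):
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale", shape=(), initializer="ones", trainable=False
        )

    def _int8_call(self, inputs, training=None):
        kernel = ops.cast(self._kernel, self.compute_dtype)
        return ops.divide(ops.multiply(inputs, kernel), self.scale)

    # Serialization hooks.

    def save_own_variables(self, store):
        # Fixed order: kernel, then scale when quantized.
        targets = [self._kernel]
        if self.quantization_mode == "int8":
            targets.append(self.scale)
        for i, variable in enumerate(targets):
            store[str(i)] = variable

    def load_own_variables(self, store):
        targets = [self._kernel]
        if self.quantization_mode == "int8":
            targets.append(self.scale)
        for i, variable in enumerate(targets):
            variable.assign(store[str(i)])
```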
Note
The @keras.saving.register_keras_serializable() decorator is needed to register the class for serialization.
Try it: quantize, save, and load
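A quick end-to-end sketch (the model architecture and file name are arbitrary; it assumes model.quantize(...) forwards the mode to the custom layer):

```python
import numpy as np

import keras

x = np.random.rand(2, 4).astype("float32")

model = keras.Sequential([keras.Input(shape=(4,)), SimpleScale()])
model.quantize("int8")
y_quantized = model(x)

model.save("simple_scale_int8.keras")
reloaded = keras.saving.load_model("simple_scale_int8.keras")
y_reloaded = reloaded(x)

# The reloaded quantized layer should reproduce the pre-save outputs.
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(y_quantized),
    keras.ops.convert_to_numpy(y_reloaded),
)
```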
Practical tips
Here are concrete patterns you can reuse when making your own layers PTQ-friendly.
Build-time vs call-time responsibilities
In build(...), if self.quantization_mode is set: allocate the quantized variables and skip allocating the float kernel to avoid duplicates.
Record any metadata you need for the call path, e.g., for INT4:
The axis you packed along (e.g., _int4_pack_axis).
The original (unpacked) length on that axis (e.g., _original_input_dim or _original_length_along_pack_axis).
In quantized call hooks, compute with the quantized buffers and de-scale back to float at the end, wherever possible. This allows you to leverage optimized low-precision kernels (e.g., cuBLAS INT8 GEMM).
INT4 specifics (packed nibbles)
Quantize to INT4 values in the range [-8, 7] (still dtype int8), then pack two 4-bit integers per byte with quantizers.pack_int4(..., axis=pack_axis).
Store the packed kernel with dtype="int8". Unpack on the fly in the call path with quantizers.unpack_int4(packed, orig_len, axis=pack_axis).
Keep the original length and pack axis so you can unpack for LoRA, gradients, and serialization.
Inputs quantization and broadcasting
In the forward path, de-scale outputs using outputs /= (inputs_scale * kernel_scale); make sure both scales broadcast to the output shape.
Dtype policy lifecycle
During quantize(mode): delete the FP32 variables, allocate quantized ones, assign values, then set self._is_quantized = True before changing the dtype policy.
Only change the policy if the current policy's quantization_mode is None, to avoid an infinite loop.
Serialization contract
Provide fixed-order logic for variable serialization so save/load is deterministic.
Write variables in a fixed order per mode (e.g., None: [kernel, bias], "int8": [kernel, bias, kernel_scale], "int4": [kernel, bias, kernel_scale]).
Validation and error handling
Validate mode early and raise a NotImplementedError for unsupported modes.
After quantization, run a tiny smoke test and assert the output matches the FP32 path within a reasonable tolerance after de-scaling.
Performance hygiene
Avoid repeated transformations in hot paths; precompute as much information as possible and keep the forward-pass hooks lightweight.
Keep quantized buffers trainable=False and prefer vectorized operations.
For more advanced patterns, refer to the Dense and EinsumDense reference implementations.