"""
2
Title: Writing Quantization-Compatible Layers in Keras
3
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
4
Date created: 2025/10/16
5
Last modified: 2025/10/16
6
Description: Complete guide for writing quantization-compatible Keras layers.
7
Accelerator: GPU
8
"""
9
10
"""
11
## What are quantization-compatible layers?
12
13
Keras lets you optimize models via post-training quantization (PTQ) by calling
14
the `layer.quantize(...)` or `model.quantize(...)` APIs. Keras exposes an
15
extensible framework for defining quantization-compatible layers. This lets you
16
author custom layers that plug into the quantization framework, can be quantized
17
to INT8 or INT4, and saved/loaded with quantization metadata.
18
19
A quantization-compatible layer needs to implement a few hooks, so that it can:
20
21
- Switch its variables to quantized representations.
22
- Use a quantization-aware forward path at inference.
23
- Save and load quantization metadata with the model.
24
25
In this guide, we'll implement a simple layer that supports INT8 PTQ. The same
26
patterns generalize to INT4 quantization and FP8 mixed-precision training.
27
"""
28
"""
## The hooks you'll implement

At minimum, your layer should define:

- `quantize(mode, **kwargs)`: Converts existing variables to quantized form and
switches the dtype policy
- `_int8_build(...)`: Allocates INT8 variables needed by your layer
- `_int8_call(inputs, training=None)`: Minimal INT8 forward path

We'll implement these for a very small layer called `SimpleScale`, which
multiplies the inputs by a trainable per-feature vector (elementwise scaling on
the last dimension). The same patterns generalize to more sophisticated layers.
"""

"""
## Writing a Simple Quantization-Compatible Layer

We start with a tiny layer that learns a per-feature multiplier. The
full-precision path just computes `y = x * w`. We'll add the quantization hooks
step by step.
"""

import numpy as np
import keras
from keras import ops, quantizers, dtype_policies
from keras.layers import Layer, Input


class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

"""
### The `quantize()` method

PTQ is a one-time rewrite. After you train or load your FP32 layer, you call
`layer.quantize("int8")`. The layer should:

1. Read its existing full-precision variables (e.g., `self._kernel`).
2. Quantize them to INT8 values plus a quantization scale.
3. Replace full-precision variables with INT8 storage and assign the quantized
data.
4. Switch the `dtype_policy` to a quantized variant (e.g., `int8_from_float32`).
"""


def quantize(self, mode, **kwargs):
    if mode != "int8":
        raise NotImplementedError(f"Unsupported quantization mode: {mode}")

    quantized_kernel, scale = quantizers.abs_max_quantize(
        self._kernel, axis=0, dtype="int8", to_numpy=True
    )
    scale = ops.squeeze(scale, axis=0)

    kernel_shape = self._kernel.shape

    del self._kernel

    # Allocate INT8 variables. Discussed in the next section.
    self._int8_build(kernel_shape)

    self._kernel.assign(quantized_kernel)
    self.scale.assign(scale)

    # `_is_quantized` should be set before changing dtype policy to inform
    # the setter that quantized variables are initialized.
    self._is_quantized = True

    if self.dtype_policy.quantization_mode is None:
        policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
        self.dtype_policy = policy

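
"""
Before moving on, it can help to see what `quantizers.abs_max_quantize` returns.
The snippet below runs it on a small NumPy array (illustrative values only): it
returns the INT8 values together with the scale, and dividing by the scale
approximately recovers the original weights.
"""

demo_weights = np.array([0.1, -0.4, 0.25, 0.8], dtype="float32")
demo_q, demo_scale = quantizers.abs_max_quantize(
    demo_weights, axis=0, dtype="int8", to_numpy=True
)
print("INT8 values:", demo_q)
print("Scale:", demo_scale)
print("De-quantized:", demo_q / demo_scale)
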
"""
#### Note

1. The `quantize(...)` method should validate `mode` and raise a
`NotImplementedError` if the mode is not supported.
2. Ensure your `quantize(...)` sets a quantized dtype policy based on the
prior policy, e.g., `int8_from_float32` or `int8_from_bfloat16`. This ensures
that the layer's `quantization_mode` is correctly set.
3. The `_is_quantized` flag should be set before changing the dtype policy to
inform the setter that quantized variables are initialized.
"""

"""
### The `_int8_build(...)` method

The `_int8_build(...)` method is called from `quantize(...)` to initialize the
INT8 variables. It should allocate:

- `self._kernel` as an INT8 vector of shape `(input_dim,)` (the same shape as
the original full-precision kernel).
- `self.scale` as the scalar quantization scale in the layer's variable dtype,
which is FP32 in this case.
"""


def _int8_build(self, kernel_shape):
    (input_dim,) = kernel_shape
    self._kernel = self.add_weight(
        name="kernel",
        shape=(input_dim,),
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        initializer="ones",
        trainable=False,
    )

"""
#### Note

1. INT8 variables should be created with `trainable=False`, as quantized parameters
are not meant to be updated during training. Subsequent fine-tuning should not
alter these quantized variables.
2. If you support INT4 quantization, implement a similar `_int4_build(...)`
method that allocates packed INT4 storage (often packed into INT8) plus
per-feature scales. The original unpacked dimensions and packing axis should
be recorded as instance variables for use in the call path. A reference
implementation is available in the Keras
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L481-L512)
layer.
"""

"""
### The `_int8_call(...)` method

The `_int8_call(...)` method implements a minimal INT8 forward path. It uses the
quantized variables allocated in `_int8_build(...)` and de-scales the output
back to floating-point.

The base `keras.Layer` class automatically dispatches to this method when the
layer is quantized, without requiring you to wire it up manually.

The INT8 path mirrors the float computation `y = x * w` but performs:

1. Elementwise multiply using the quantized weight.
2. De-scale back to float by dividing by the `scale`.
"""


def _int8_call(self, inputs, training=None):
    x = ops.multiply(inputs, self._kernel)
    x = ops.divide(x, self.scale)
    return x

"""
## Complete `SimpleScale` class with hooks

Below is the full class definition, incorporating all the hooks shown above
(`quantize`, `_int8_build`, `_int8_call`).
"""


class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x

"""
## Try it: quantize and run a forward pass

Below we create the layer, run a forward pass in float, quantize it to INT8, and
call it again.
"""

# Sample inputs
rng = np.random.default_rng()
x = rng.random((2, 4)).astype("float32")

layer = SimpleScale()

# Forward pass in float
y_fp = layer(x)

# Quantize to INT8 and run again
layer.quantize("int8")
y_int8 = layer(x)

print("SimpleScale FP32 sample:", y_fp[0].numpy())
print("SimpleScale INT8 sample:", y_int8[0].numpy())

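
"""
As a quick sanity check, we can compare the INT8 output against the FP32 output.
The two will not match exactly, but for this simple layer the quantization error
should be small. The check below is illustrative, not a Keras requirement.
"""

max_abs_err = np.max(np.abs(ops.convert_to_numpy(y_fp) - ops.convert_to_numpy(y_int8)))
print("Max absolute difference (FP32 vs INT8):", max_abs_err)
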
"""
## Extending to INT4

If you want to support INT4 quantization, add:

- `_int4_build(...)`: allocate packed 4-bit storage (often packed into int8) plus per-feature scales
- `_int4_call(...)`: unpack at runtime and follow the same de-scale pattern
- `quantize("int4")`: quantize weights with `value_range=(-8, 7)`, then pack to int4 storage

See the
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py#L602-L653)
reference for a complete packed int4 example, including how to track and use the
original (unpacked) dimension in the call path. An illustrative sketch of these
hooks for `SimpleScale` follows below.
"""

"""
## Adding Serialization Support

Keras depends on a fixed serialization contract for saving and loading models.
This contract is complicated by quantization, since the variables you need to
save and load depend on the quantization mode.

The framework provides two hooks for layers to customize variable serialization:

- `save_own_variables(self, store)`: Write variables to `store` in a fixed
order.
- `load_own_variables(self, store)`: Read variables from `store` in the same
order.

Additionally, the `build(...)` method should be modified to allocate the correct
variables based on the presence (or absence) of `self.quantization_mode`.

For this layer we only aim to support two modes (non-quantized and INT8), so the
serialization contract is:

- None (no quantization): `["kernel"]`
- INT8: `["kernel", "scale"]`

The following code implements the required hooks; Keras will call them during
`model.save(...)` and `keras.saving.load_model(...)`.
"""


def save_own_variables(self, store):
    # Write variables to `store` in a fixed order based on quantization mode.
    # `store` is a key-value mapping provided by Keras during model.save().
    # Values are tensors.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        # Order: _kernel
        store[str(idx)] = self._kernel
    elif mode == "int8":
        # Order: _kernel, scale
        store[str(idx)] = self._kernel
        idx += 1
        store[str(idx)] = self.scale
    else:
        raise ValueError(f"Unsupported quantization mode for save: {mode}")


def load_own_variables(self, store):
    # Read variables from `store` in the same order used by
    # `save_own_variables`. Keras calls this during
    # `keras.saving.load_model(...)`.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        self._kernel.assign(store[str(idx)])
    elif mode == "int8":
        self._kernel.assign(store[str(idx)])
        idx += 1
        self.scale.assign(store[str(idx)])
    else:
        raise ValueError(f"Unsupported quantization mode for load: {mode}")

"""
### Modify the `build(...)` method

The `build(...)` method itself also needs to be aware of the quantization mode.
If a saved quantized layer is being loaded/deserialized, `self.quantization_mode`
will be set before `build(...)` is called. In that case, we need to allocate
quantized variables directly instead of full-precision ones.
"""


def build(self, input_shape):
    input_dim = input_shape[-1]

    # Quantized build path.
    if self.quantization_mode:
        if self.quantization_mode == "int8":
            self._int8_build((input_dim,))
    else:
        # Regular FP32 build path.
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

"""
## Complete implementation with serialization

The full class with serialization support looks like this:
"""


@keras.saving.register_keras_serializable()
class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]

        if self.quantization_mode:
            if self.quantization_mode == "int8":
                self._int8_build((input_dim,))
        else:
            self._kernel = self.add_weight(
                name="kernel",
                shape=(input_dim,),
                initializer="random_uniform",
            )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x

    def save_own_variables(self, store):
        # Write variables to `store` in a fixed order based on quantization mode.
        # `store` is a key-value mapping provided by Keras during model.save(); values are tensors.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            # Order: _kernel
            store[str(idx)] = self._kernel
        elif mode == "int8":
            # Order: _kernel, scale
            store[str(idx)] = self._kernel
            idx += 1
            store[str(idx)] = self.scale
        else:
            raise ValueError(f"Unsupported quantization mode for save: {mode}")

    def load_own_variables(self, store):
        # Read variables from `store` in the same order used by `save_own_variables`.
        # Keras calls this during `keras.saving.load_model(...)`.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            self._kernel.assign(store[str(idx)])
        elif mode == "int8":
            self._kernel.assign(store[str(idx)])
            idx += 1
            self.scale.assign(store[str(idx)])
        else:
            raise ValueError(f"Unsupported quantization mode for load: {mode}")

"""
#### Note

The `@keras.saving.register_keras_serializable()` decorator registers the class
with the Keras serialization framework so that the custom layer can be
reconstructed when the model is loaded.
"""

"""
## Try it: quantize, save, and load
"""

model = keras.Sequential([Input(shape=(4,)), SimpleScale()])
model.build((None, 4))

# Quantize to INT8.
model.quantize("int8")
y_int8 = model(x)
print("SimpleScale INT8 sample:", y_int8[0].numpy())

# Save and load the quantized model.
model.save("simplescale_int8.keras")
loaded = keras.saving.load_model("simplescale_int8.keras")

y_loaded = loaded(x)
print("Loaded INT8 sample:", y_loaded[0].numpy())

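
"""
Since loading restores the same INT8 kernel and scale, the reloaded model should
reproduce the quantized model's outputs (up to floating-point round-off). A small
check makes this explicit:
"""

np.testing.assert_allclose(
    ops.convert_to_numpy(y_int8), ops.convert_to_numpy(y_loaded), rtol=1e-5, atol=1e-6
)
print("Loaded model output matches the quantized model output.")
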
"""
## Practical tips

Here are concrete patterns you can reuse when making your own layers PTQ-friendly.

- Build-time vs call-time responsibilities
    - In `build(...)`, if `self.quantization_mode` is set: allocate the quantized
      variables and skip allocating the float kernel to avoid duplicates.
    - Record any metadata you need for the call path, e.g., for INT4:
        - The axis you packed along (e.g., `_int4_pack_axis`).
        - The original (unpacked) length on that axis (e.g., `_original_input_dim`
          or `_original_length_along_pack_axis`).
    - In quantized call hooks, compute with the quantized buffers and de-scale back
      to float at the end, wherever possible. This allows you to leverage optimized
      low-precision kernels (e.g., cuBLAS INT8 GEMM).

- INT4 specifics (packed nibbles)
    - Quantize to INT4 values in the range [-8, 7] (still dtype int8), then pack two
      4-bit integers per byte with `quantizers.pack_int4(..., axis=pack_axis)`.
    - Store the packed kernel with `dtype="int8"`. Unpack on the fly in the call
      path with `quantizers.unpack_int4(packed, orig_len, axis=pack_axis)`.
    - Keep the original length and pack axis so you can unpack for LoRA,
      gradients, and serialization.

- Inputs quantization and broadcasting
    - In the forward path, de-scale outputs using
      `outputs /= (inputs_scale * kernel_scale)`; make sure both scales broadcast
      to the output shape. A sketch of this pattern follows after this list.

- Dtype policy lifecycle
    - During `quantize(mode)`: delete FP32 variables, allocate quantized ones,
      assign values, then set `self._is_quantized = True` before changing the
      dtype policy.
    - Only change the policy if the current policy has `quantization_mode is None`,
      to avoid an infinite loop.

- Serialization contract
    - Use fixed-order logic for variable serialization so that save/load is
      deterministic.
    - Write variables in a fixed order per mode (e.g., None: [kernel, bias],
      `"int8"`: [kernel, bias, kernel_scale], `"int4"`: [kernel, bias, kernel_scale]).

- Validation and error handling
    - Validate `mode` early and raise a `NotImplementedError` for unsupported
      modes.
    - After quantization, run a tiny smoke test and assert the output matches the
      FP32 path within a reasonable tolerance after de-scaling.

- Performance hygiene
    - Avoid repeated transformations in hot paths; precompute as much information
      as possible and keep the forward-pass hooks lightweight.
    - Keep quantized buffers `trainable=False` and prefer vectorized operations.
"""

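"""
To make the input-quantization and broadcasting rule above concrete, here is an
illustrative sketch of an INT8 call path for a dense-like layer that also
quantizes its inputs on the fly. The attribute names (`self._kernel`,
`self.kernel_scale`) and the dynamic per-row input quantization are assumptions
for this example; the Keras `Dense` layer is the reference implementation.
"""


def _int8_call_with_input_quantization(self, inputs, training=None):
    # Dynamically quantize the inputs along the last axis (per-row abs-max).
    inputs_q, inputs_scale = quantizers.abs_max_quantize(inputs, axis=-1, dtype="int8")
    # Matmul in the compute dtype, then de-scale by both scales. `inputs_scale`
    # has shape (..., 1) and `kernel_scale` has shape (units,), so both broadcast
    # against the output.
    x = ops.matmul(
        ops.cast(inputs_q, self.compute_dtype),
        ops.cast(self._kernel, self.compute_dtype),
    )
    return ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale))
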
"""
For more advanced patterns, refer to the
[Dense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/dense.py) and
[EinsumDense](https://github.com/keras-team/keras/blob/3c3d6adc08db627d89b5ad5e7f9b0ba3e88f2641/keras/src/layers/core/einsum_dense.py)
reference implementations.
"""