Commit d19fece

Adds Native GPTQ Layer, Runtime Execution, and Serialization Support (#21641)
* initial commit
* documentation
* dtype alignment
* cleanup model.quantize() + increase top1 tolerance for tests
* test fix
* cleanup layer discovery inside transformer blocks
* delete original kernel slightly earlier
* serialization support
* Fixed tests
* Clean up dense and einsum_dense code
* format
* fix serialization issues
* Added tests for vectorized groupwise (de)quantization
* convert g_idx to float32 since int32 cannot reside on GPU in TF
* addresses reviews
* encodes gptq config in GPTQDTypePolicy name
* gptqdtypepolicy fix
* set correct _name parameter for GPTQDTypePolicy
* Address reviews
* renames self.gptq to self._is_gptq_calibrated
* fix typo
1 parent a542c5f commit d19fece

File tree: 20 files changed, +1231 −325 lines


keras/api/_tf_keras/keras/dtype_policies/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@
 from keras.src.dtype_policies.dtype_policy import (
     FloatDTypePolicy as FloatDTypePolicy,
 )
+from keras.src.dtype_policies.dtype_policy import (
+    GPTQDTypePolicy as GPTQDTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import (
     QuantizedDTypePolicy as QuantizedDTypePolicy,
 )

keras/api/dtype_policies/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@
 from keras.src.dtype_policies.dtype_policy import (
     FloatDTypePolicy as FloatDTypePolicy,
 )
+from keras.src.dtype_policies.dtype_policy import (
+    GPTQDTypePolicy as GPTQDTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import (
     QuantizedDTypePolicy as QuantizedDTypePolicy,
 )
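
Both API files above re-export the same symbol, so the new policy becomes reachable from the public `keras.dtype_policies` namespace. A minimal sanity check (hypothetical usage, assuming a Keras build that includes this commit):

    from keras.dtype_policies import GPTQDTypePolicy

    policy = GPTQDTypePolicy(mode="gptq/4/128", source_name="float32")
    print(policy.name)  # "gptq/4/128_from_float32"
    print(policy.weight_bits, policy.group_size)  # 4 128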

keras/src/dtype_policies/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicyMap
@@ -14,6 +15,7 @@
     QuantizedDTypePolicy,
     QuantizedFloat8DTypePolicy,
     DTypePolicyMap,
+    GPTQDTypePolicy,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
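
Adding the class to `ALL_OBJECTS` (and hence `ALL_OBJECTS_DICT`) is what lets a saved policy config resolve back to `GPTQDTypePolicy` by class name at load time. A short round-trip sketch, assuming the public `keras.dtype_policies.serialize`/`deserialize` helpers behave here as they do for the other policies:

    import keras

    policy = keras.dtype_policies.GPTQDTypePolicy(
        mode="gptq/4/128", source_name="float32"
    )
    config = keras.dtype_policies.serialize(policy)
    restored = keras.dtype_policies.deserialize(config)
    assert isinstance(restored, keras.dtype_policies.GPTQDTypePolicy)
    assert restored == policy  # __eq__ also compares weight_bits/group_size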

keras/src/dtype_policies/dtype_policy.py

Lines changed: 90 additions & 0 deletions
@@ -288,6 +288,94 @@ def get_config(self):
         return config
 
 
+@keras_export("keras.dtype_policies.GPTQDTypePolicy")
+class GPTQDTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for GPTQ quantization.
+
+    This policy helps propagate quantization settings for GPTQ
+    when loading a GPTQ quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"gptq/<weight_bits>/<group_size>"`.
+            - `"gptq"`: The identifier for the quantization algorithm.
+            - `<weight_bits>`: Number of bits to quantize weights to.
+              Supported values are 2, 3, 4, and 8.
+            - `<group_size>`: The group size for quantization. Supported
+              values are -1 (for whole-tensor quantization) or any
+              positive integer. Typically, a smaller group size leads
+              to better accuracy but slower speed.
+            Example: `"gptq/4/128"`.
+        source_name: The source dtype policy name, e.g. `"float32"`.
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'gptq/<weight_bits>/<group_size>'"
+
+        # Validate format
+        if len(parts) != 3 or parts[0] != "gptq":
+            raise ValueError(
+                "Invalid mode for GPTQDTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast weight_bits and group_size
+        try:
+            weight_bits = int(parts[1])
+            group_size = int(parts[2])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for GPTQDTypePolicy. <weight_bits> and "
+                "<group_size> must be integers. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate supported values
+        if weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Invalid weight_bits in mode. Supported values are "
+                f"2, 3, 4, and 8, but got {weight_bits} from '{mode}'."
+            )
+
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size in mode. Supported values are "
+                "-1 (whole-tensor) or a positive integer, "
+                f"but got {group_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return (
+            self.weight_bits == other.weight_bits
+            and self.group_size == other.group_size
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export(
     [
         "keras.config.set_dtype_policy",
@@ -352,6 +440,8 @@ def _get_quantized_dtype_policy_by_str(policy):
     mode, source_name = split_name
     if policy.startswith("int8") or policy.startswith("int4"):
         return QuantizedDTypePolicy(mode, source_name)
+    elif policy.startswith("gptq"):
+        return GPTQDTypePolicy(mode, source_name)
     elif policy.startswith("float8"):
         return QuantizedFloat8DTypePolicy(mode, source_name)
     else:
keras/src/dtype_policies/dtype_policy_test.py

Lines changed: 53 additions & 0 deletions
@@ -9,6 +9,7 @@
 from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy import dtype_policy
 from keras.src.dtype_policies.dtype_policy import set_dtype_policy
+from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.testing import test_case
@@ -691,3 +692,55 @@ def test_set_policy_none(self):
         """Test setting the policy to None."""
         with self.assertRaisesRegex(ValueError, "Invalid `policy` argument"):
             set_dtype_policy(None)
+
+
+class GPTQConfigErrorHandlingTest(test_case.TestCase):
+    """Test error handling in GPTQConfig."""
+
+    def test_invalid_weight_bits(self):
+        with self.assertRaisesRegex(ValueError, "Unsupported weight_bits"):
+            GPTQConfig(
+                dataset=None,
+                tokenizer=None,
+                weight_bits=5,
+            )
+
+    def test_negative_num_samples(self):
+        with self.assertRaisesRegex(
+            ValueError, "num_samples must be a positive integer."
+        ):
+            GPTQConfig(
+                dataset=None,
+                tokenizer=None,
+                num_samples=-10,
+            )
+
+    def test_zero_sequence_length(self):
+        with self.assertRaisesRegex(
+            ValueError, "sequence_length must be a positive integer."
+        ):
+            GPTQConfig(
+                dataset=None,
+                tokenizer=None,
+                sequence_length=0,
+            )
+
+    def test_invalid_hessian_damping(self):
+        with self.assertRaisesRegex(
+            ValueError, "hessian_damping must be between 0 and 1."
+        ):
+            GPTQConfig(
+                dataset=None,
+                tokenizer=None,
+                hessian_damping=1.5,
+            )
+
+    def test_invalid_group_size(self):
+        with self.assertRaisesRegex(
+            ValueError, "Invalid group_size. Supported values are -1"
+        ):
+            GPTQConfig(
+                dataset=None,
+                tokenizer=None,
+                group_size=0,
+            )