Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from keras.src.dtype_policies.dtype_policy import (
FloatDTypePolicy as FloatDTypePolicy,
)
from keras.src.dtype_policies.dtype_policy import (
GPTQDTypePolicy as GPTQDTypePolicy,
)
from keras.src.dtype_policies.dtype_policy import (
QuantizedDTypePolicy as QuantizedDTypePolicy,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from keras.src.dtype_policies.dtype_policy import (
FloatDTypePolicy as FloatDTypePolicy,
)
from keras.src.dtype_policies.dtype_policy import (
GPTQDTypePolicy as GPTQDTypePolicy,
)
from keras.src.dtype_policies.dtype_policy import (
QuantizedDTypePolicy as QuantizedDTypePolicy,
)
Expand Down
2 changes: 2 additions & 0 deletions keras/src/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
Expand All @@ -14,6 +15,7 @@
QuantizedDTypePolicy,
QuantizedFloat8DTypePolicy,
DTypePolicyMap,
GPTQDTypePolicy,
}
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}

Expand Down
51 changes: 51 additions & 0 deletions keras/src/dtype_policies/dtype_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,55 @@ def get_config(self):
return config


@keras_export("keras.dtype_policies.GPTQDTypePolicy")
class GPTQDTypePolicy(QuantizedDTypePolicy):
"""Quantized dtype policy for GPTQ quantization.

This policy helps propagate quantization settings for GPTQ
when loading a GPTQ quantized model in Keras format.

Args:
mode: The quantization mode, "gptq".
source_name: The source dtype policy name, e.g. "float32".
weight_bits: Number of bits to quantize weights to. Supported values
are 2, 3, 4, and 8.
group_size: The group size for quantization. Supported values are
-1 (for whole-tensor quantization) or any positive integer.
Typically a smaller group size leads to better accuracy but
slower speed.
"""

def __init__(
self,
mode,
source_name=None,
):
mode, weight_bits, group_size = mode.split("/")
super().__init__(
mode=mode,
source_name=source_name,
)

self._name = f"{mode}/{weight_bits}/{group_size}_from_{source_name}"
self.mode = mode
self.weight_bits = int(weight_bits)
self.group_size = int(group_size)

def __eq__(self, other):
if super().__eq__(other) is False:
return False
return (
self.weight_bits == other.weight_bits
and self.group_size == other.group_size
)

def get_config(self):
config = super().get_config()
mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
config.update({"mode": mode})
return config


@keras_export(
[
"keras.config.set_dtype_policy",
Expand Down Expand Up @@ -352,6 +401,8 @@ def _get_quantized_dtype_policy_by_str(policy):
mode, source_name = split_name
if policy.startswith("int8") or policy.startswith("int4"):
return QuantizedDTypePolicy(mode, source_name)
elif policy.startswith("gptq"):
return GPTQDTypePolicy(mode, source_name)
elif policy.startswith("float8"):
return QuantizedFloat8DTypePolicy(mode, source_name)
else:
Expand Down
Loading