
Commit cd48a2c

[temp] serialization stuff
1 parent 88a7542 commit cd48a2c

4 files changed: +144 -23 lines

keras/src/layers/core/dense.py
Lines changed: 40 additions & 10 deletions

@@ -90,6 +90,7 @@ def __init__(
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
@@ -107,11 +108,16 @@ def __init__(
         self.lora_enabled = False
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True
+        self.quantization_config = quantization_config

     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         if self.quantization_mode not in ("int8", "int4", "gptq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
@@ -238,6 +244,11 @@ def save_own_variables(self, store):
                 target_variables.append(self.kernel_amax_history)
                 target_variables.append(self.outputs_grad_scale)
                 target_variables.append(self.outputs_grad_amax_history)
+            elif self.quantization_mode == "gptq":
+                target_variables.append(self.quantized_kernel)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_zero)
+                target_variables.append(self.g_idx)
             else:
                 raise self._quantization_mode_error(self.quantization_mode)
         for i, variable in enumerate(target_variables):
@@ -264,6 +275,11 @@ def load_own_variables(self, store):
                 target_variables.append(self.kernel_amax_history)
                 target_variables.append(self.outputs_grad_scale)
                 target_variables.append(self.outputs_grad_amax_history)
+            elif self.quantization_mode == "gptq":
+                target_variables.append(self.quantized_kernel)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_zero)
+                target_variables.append(self.g_idx)
             else:
                 raise self._quantization_mode_error(self.quantization_mode)
         for i, variable in enumerate(target_variables):
@@ -289,11 +305,25 @@ def get_config(self):
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
         }
+        if self.quantization_config:
+            config["quantization_config"] = self.quantization_config
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config = config.copy()
+        from keras.src.saving import deserialize_keras_object
+
+        if "quantization_config" in config:
+            config["quantization_config"] = deserialize_keras_object(
+                config["quantization_config"],
+                custom_objects=custom_objects,
+            )
+        return cls(**config)
+
     def _check_load_own_variables(self, store):
         all_vars = self._trainable_variables + self._non_trainable_variables
         if len(store.keys()) != len(all_vars):
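
For context on the two methods above: get_config now records the quantization config in the layer's config dict, and from_config passes it back through deserialize_keras_object when rebuilding the layer. A minimal construction-side sketch, assuming the public keras.quantizers.GPTQConfig export and using placeholder calibration arguments:

from keras import layers
from keras.quantizers import GPTQConfig  # public export path assumed here

# Placeholder calibration arguments; a real GPTQ run supplies an actual
# calibration dataset and tokenizer.
gptq_config = GPTQConfig(dataset=["calibration text"], tokenizer=None)

layer = layers.Dense(16, quantization_config=gptq_config)
config = layer.get_config()

# The config dict now carries the quantization config alongside the usual
# Dense arguments; the LoRA keys are still only added when lora_rank is set.
print("quantization_config" in config)  # True
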
@@ -328,19 +358,19 @@ def _check_load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )

-    # Quantization-related (int8 and float8) methods
-
-    def quantized_build(self, kernel_shape, mode, config):
+    def quantized_build(self, input_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(input_shape)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(input_shape)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
-            self._gptq_build(kernel_shape, config)
+            self._gptq_build(input_shape, config)
         else:
             raise self._quantization_mode_error(mode)
+        if config is not None:
+            self.quantization_config = config
         self._is_quantized = True

     def _int8_build(self, kernel_shape):
@@ -371,10 +401,10 @@ def _gptq_build(self, kernel_shape, config):
             trainable=False,
         )

-        if config.group_size == -1:
+        if config["group_size"] == -1:
             n_groups = 1
         else:
-            n_groups = ceil(self.kernel_shape[0] / config.group_size)
+            n_groups = ceil(self.kernel_shape[0] / config["group_size"])
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
             shape=(self.units, n_groups),
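
The group count above follows ceil(rows / group_size), with group_size == -1 collapsing everything into a single group. A standalone sketch of that arithmetic; the helper name is illustrative and not part of the change:

from math import ceil

def n_groups_for(rows, group_size):
    # Mirrors the branch in _gptq_build: -1 means one group spanning all
    # rows, otherwise a partial trailing group still counts as a group.
    return 1 if group_size == -1 else ceil(rows / group_size)

assert n_groups_for(768, -1) == 1
assert n_groups_for(768, 128) == 6
assert n_groups_for(770, 128) == 7
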
@@ -761,7 +791,7 @@ def _get_kernel_with_merged_lora(self):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode is None:
+        if self.dtype_policy.quantization_mode in (None, "gptq"):
             return self.kernel, None

         kernel_value = self._kernel

keras/src/layers/core/einsum_dense.py
Lines changed: 40 additions & 10 deletions

@@ -133,6 +133,7 @@ def __init__(
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -152,6 +153,7 @@ def __init__(
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config

     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -164,7 +166,11 @@ def build(self, input_shape):
         self.full_output_shape = tuple(full_output_shape)
         self.input_spec = InputSpec(ndim=len(input_shape))
         if self.quantization_mode is not None:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
@@ -297,6 +303,11 @@ def save_own_variables(self, store):
                 target_variables.append(self.kernel_amax_history)
                 target_variables.append(self.outputs_grad_scale)
                 target_variables.append(self.outputs_grad_amax_history)
+            elif self.quantization_mode == "gptq":
+                target_variables.append(self.quantized_kernel)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_zero)
+                target_variables.append(self.g_idx)
             else:
                 raise self._quantization_mode_error(self.quantization_mode)
         for i, variable in enumerate(target_variables):
@@ -323,6 +334,11 @@ def load_own_variables(self, store):
                 target_variables.append(self.kernel_amax_history)
                 target_variables.append(self.outputs_grad_scale)
                 target_variables.append(self.outputs_grad_amax_history)
+            elif self.quantization_mode == "gptq":
+                target_variables.append(self.quantized_kernel)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_zero)
+                target_variables.append(self.g_idx)
             else:
                 raise self._quantization_mode_error(self.quantization_mode)
         for i, variable in enumerate(target_variables):
@@ -352,11 +368,25 @@ def get_config(self):
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
         }
+        if self.quantization_config:
+            config["quantization_config"] = self.quantization_config
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        config = config.copy()
+        from keras.src.saving import deserialize_keras_object
+
+        if "quantization_config" in config:
+            config["quantization_config"] = deserialize_keras_object(
+                config["quantization_config"],
+                custom_objects=custom_objects,
+            )
+        return cls(**config)
+
     def _check_load_own_variables(self, store):
         all_vars = self._trainable_variables + self._non_trainable_variables
         if len(store.keys()) != len(all_vars):
@@ -391,19 +421,19 @@ def _check_load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )

-    # Quantization-related (int8 and float8) methods
-
-    def quantized_build(self, kernel_shape, mode, config):
+    def quantized_build(self, input_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(input_shape)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(input_shape)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
-            self._gptq_build(kernel_shape, config=config)
+            self._gptq_build(input_shape, config=config)
         else:
             raise self._quantization_mode_error(mode)
+        if config is not None:
+            self.quantization_config = config
         self._is_quantized = True

     def _int8_build(self, kernel_shape):
@@ -466,10 +496,10 @@ def _gptq_build(self, kernel_shape, config):
         else:
             raise ValueError("Could not determine row/column split.")

-        if config.group_size == -1:
+        if config["group_size"] == -1:
             n_groups = 1
         else:
-            n_groups = ceil(rows / config.group_size)
+            n_groups = ceil(rows / config["group_size"])

         if hasattr(self, "_set_quantization_info"):
             self._set_quantization_info()
@@ -965,7 +995,7 @@ def _get_kernel_with_merged_lora(self):
                 This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode is None:
+        if self.dtype_policy.quantization_mode in (None, "gptq"):
             return self.kernel, None

         # If quantized but LoRA is not enabled, return the original quantized

keras/src/quantizers/gptq_config.py
Lines changed: 21 additions & 0 deletions

@@ -157,3 +157,24 @@ def __init__(
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def get_config(self):
+        return {
+            "dataset": self.dataset,
+            "tokenizer": self.tokenizer,
+            "num_samples": self.num_samples,
+            "per_channel": self.per_channel,
+            "sequence_length": self.sequence_length,
+            "hessian_damping": self.hessian_damping,
+            "weight_bits": self.weight_bits,
+            "group_size": self.group_size,
+            "symmetric": self.symmetric,
+            "activation_order": self.activation_order,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
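
With these additions GPTQConfig behaves like a standard serializable Keras object, and __getitem__ forwards to attribute access, so the config["group_size"] lookups in the layers work on a GPTQConfig instance as well as on a plain dict. A quick round-trip sketch with placeholder constructor arguments:

from keras.quantizers import GPTQConfig  # public export path assumed here

# Placeholder arguments; real usage passes actual calibration data.
cfg = GPTQConfig(dataset=["calibration text"], tokenizer=None, group_size=128)

payload = cfg.get_config()               # plain dict of constructor kwargs
restored = GPTQConfig.from_config(payload)

# Dict-style and attribute access agree because __getitem__ uses getattr.
assert restored["group_size"] == restored.group_size == 128
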

keras/src/quantizers/gptq_test.py
Lines changed: 43 additions & 3 deletions

@@ -1,3 +1,4 @@
+import os
 from collections.abc import Callable

 import numpy as np
@@ -9,6 +10,7 @@
 from keras.src import backend
 from keras.src import layers
 from keras.src import ops
+from keras.src import saving
 from keras.src import testing
 from keras.src.quantizers.gptq import GPTQ
 from keras.src.quantizers.gptq import _stable_permutation
@@ -410,16 +412,31 @@ def _get_sequence_classifier():
     num_heads = 4
     ff_dim = 32

+    @keras.saving.register_keras_serializable(package="GPTQTest")
     class SimpleTransformerBlock(layers.Layer):
         def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):
             super().__init__(**kwargs)
+            self.embed_dim = embed_dim
+            self.num_heads = num_heads
+            self.ff_dim = ff_dim
+
             self.att = layers.MultiHeadAttention(
-                num_heads=num_heads, key_dim=embed_dim // num_heads
+                num_heads=num_heads, key_dim=embed_dim // num_heads, **kwargs
             )
+            sub_kwargs = kwargs.copy()
+            sub_kwargs.pop("name", None)
             self.ffn = models.Sequential(
                 [
-                    layers.Dense(ff_dim, activation="relu", use_bias=True),
-                    layers.Dense(embed_dim, use_bias=True),
+                    layers.Dense(
+                        ff_dim,
+                        activation="relu",
+                        use_bias=True,
+                        name="ffn_dense_1",
+                        **kwargs,
+                    ),
+                    layers.Dense(
+                        embed_dim, use_bias=True, name="ffn_dense_2", **kwargs
+                    ),
                 ]
             )
             self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
@@ -431,6 +448,19 @@ def call(self, inputs):
             ffn_output = self.ffn(out1)
             return self.layernorm2(out1 + ffn_output)

+        def get_config(self):
+            base_config = super().get_config()
+            config = {
+                "embed_dim": self.embed_dim,
+                "num_heads": self.num_heads,
+                "ff_dim": self.ff_dim,
+            }
+            return {**base_config, **config}
+
+        @classmethod
+        def from_config(cls, config):
+            return cls(**config)
+
     inputs = layers.Input(shape=(SEQ_LEN,), dtype="int32")
     x = layers.Embedding(VOCAB_SIZE, embed_dim)(inputs)
     x = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)(x)
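
The registration decorator plus the get_config/from_config pair added to SimpleTransformerBlock are what let a saved model containing the block be reloaded. The same pattern in isolation, on a hypothetical toy layer that is not part of the test file:

import keras
from keras import layers

@keras.saving.register_keras_serializable(package="GPTQTest")
class ScaleLayer(layers.Layer):
    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Record constructor arguments so the layer can be rebuilt on load.
        return {**super().get_config(), "factor": self.factor}

model = keras.Sequential([keras.Input(shape=(4,)), ScaleLayer(2.0)])
model.save("toy_scale.keras")
reloaded = keras.saving.load_model("toy_scale.keras")  # resolved via the registry
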
@@ -617,6 +647,16 @@ def test_quantize_gptq_combinations(self, dataset, config):
         )
         self.assertLessEqual(kl, 0.30, f"KL divergence too high: {kl:.3f}")

+        # Save and load the quantized model
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "quantized_model.keras"
+        )
+        model.save(temp_filepath)
+        loaded = saving.load_model(temp_filepath)
+        self.assertAllClose(
+            model.predict(x_eval), loaded.predict(x_eval), atol=1e-6
+        )
+
     @parameterized.named_parameters(
         {
             "testcase_name": "gptq_with_invalid_config",

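End to end, the new test lines cover quantize-then-save-then-reload: GPTQ-quantize the model, write a .keras archive, load it back, and check that predictions match. A rough sketch of that flow; model, gptq_config, and x_eval stand for objects built earlier in the test, and the exact keyword for passing the config to Model.quantize is an assumption:

import keras
import numpy as np

# `model`, `gptq_config`, and `x_eval` come from the surrounding test setup.
model.quantize("gptq", config=gptq_config)  # keyword name assumed

model.save("quantized_model.keras")
reloaded = keras.saving.load_model("quantized_model.keras")

# The GPTQ variables (quantized_kernel, kernel_scale, kernel_zero, g_idx)
# are restored by load_own_variables, so outputs should match closely.
np.testing.assert_allclose(
    model.predict(x_eval), reloaded.predict(x_eval), atol=1e-6
)
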