@@ -288,6 +288,94 @@ def get_config(self):
288
288
return config
289
289
290
290
291
@keras_export("keras.dtype_policies.GPTQDTypePolicy")
class GPTQDTypePolicy(QuantizedDTypePolicy):
    """Quantized dtype policy for GPTQ quantization.

    This policy helps propagate quantization settings for GPTQ
    when loading a GPTQ quantized model in Keras format.

    Args:
        mode: The quantization mode. This should be a string in the format
            `"gptq/<weight_bits>/<group_size>"`.
            - `"gptq"`: The identifier for the quantization algorithm.
            - `<weight_bits>`: Number of bits to quantize weights to.
                Supported values are 2, 3, 4, and 8.
            - `<group_size>`: The group size for quantization. Supported
                values are -1 (for whole-tensor quantization) or any
                positive integer. Typically a smaller group size leads
                to better accuracy but slower speed.
            Example: `"gptq/4/128"`.
        source_name: The source dtype policy name, e.g. "float32".

    Raises:
        ValueError: If `mode` is not of the form
            `"gptq/<weight_bits>/<group_size>"`, if the bit-width or group
            size are not integers, or if they fall outside the supported
            values documented above.
    """

    def __init__(
        self,
        mode,
        source_name=None,
    ):
        parts = mode.split("/")
        expected_format = "'gptq/<weight_bits>/<group_size>'"

        # Validate format: exactly three "/"-separated fields, the first
        # of which must be the literal algorithm identifier "gptq".
        if len(parts) != 3 or parts[0] != "gptq":
            raise ValueError(
                "Invalid mode for GPTQDTypePolicy. Expected format "
                f"{expected_format}, but got '{mode}'."
            )

        # Validate and cast weight_bits and group_size. Chain with
        # `from None` so the caller sees only the domain-specific error,
        # not the internal int() traceback.
        try:
            weight_bits = int(parts[1])
            group_size = int(parts[2])
        except ValueError:
            raise ValueError(
                "Invalid mode for GPTQDTypePolicy. <weight_bits> and "
                "<group_size> must be integers. Expected format "
                f"{expected_format}, but got '{mode}'."
            ) from None

        # Validate supported values.
        if weight_bits not in (2, 3, 4, 8):
            raise ValueError(
                "Invalid weight_bits in mode. Supported values are "
                f"2, 3, 4, and 8, but got {weight_bits} from '{mode}'."
            )

        if group_size < -1 or group_size == 0:
            raise ValueError(
                "Invalid group_size in mode. Supported values are "
                "-1 (whole-tensor) or a positive integer, "
                f"but got {group_size} from '{mode}'."
            )

        # The parent policy only knows about the bare algorithm name
        # ("gptq"); the bit-width/group-size details live on this subclass.
        base_mode = parts[0]
        super().__init__(
            mode=base_mode,
            source_name=source_name,
        )

        # Keep the full mode string in the policy name so a serialized
        # policy round-trips its quantization settings.
        self._name = f"{mode}_from_{source_name}"
        self.mode = base_mode
        self.weight_bits = weight_bits
        self.group_size = group_size

    def __eq__(self, other):
        # Guard on the type first: without it, an `other` that compares
        # equal at the superclass level but lacks the GPTQ-specific
        # attributes would raise AttributeError below instead of
        # returning False. (This also covers super().__eq__ returning
        # NotImplemented, which is truthy and would slip past an
        # `is False` check alone.)
        if not isinstance(other, GPTQDTypePolicy):
            return False
        if super().__eq__(other) is False:
            return False
        return (
            self.weight_bits == other.weight_bits
            and self.group_size == other.group_size
        )

    def get_config(self):
        config = super().get_config()
        # Reconstruct the full mode string for serialization
        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
        config.update({"mode": mode})
        return config
377
+
378
+
291
379
@keras_export (
292
380
[
293
381
"keras.config.set_dtype_policy" ,
@@ -352,6 +440,8 @@ def _get_quantized_dtype_policy_by_str(policy):
352
440
mode , source_name = split_name
353
441
if policy .startswith ("int8" ) or policy .startswith ("int4" ):
354
442
return QuantizedDTypePolicy (mode , source_name )
443
+ elif policy .startswith ("gptq" ):
444
+ return GPTQDTypePolicy (mode , source_name )
355
445
elif policy .startswith ("float8" ):
356
446
return QuantizedFloat8DTypePolicy (mode , source_name )
357
447
else :
0 commit comments