Skip to content

Commit 904d0a0

Browse files
Convert g_idx to float32, since int32 tensors cannot be placed on the GPU in the TensorFlow backend
1 parent ade009e commit 904d0a0

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

keras/src/layers/core/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def _gptq_build(self, kernel_shape, config):
440440
name="g_idx",
441441
shape=(self.kernel_shape[0],),
442442
initializer="zeros",
443-
dtype="int32",
443+
dtype="float32",
444444
trainable=False,
445445
)
446446

keras/src/layers/core/einsum_dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def _gptq_build(self, kernel_shape, config):
549549
name="g_idx",
550550
shape=(rows,),
551551
initializer="zeros",
552-
dtype="int32",
552+
dtype="float32",
553553
trainable=False,
554554
)
555555

keras/src/quantizers/gptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def gptq_quantize_matrix(
247247
# g_idx in permuted domain
248248
g_idx = ops.arange(0, in_features, dtype="int32")
249249
g_idx = ops.divide(g_idx, base_group)
250-
g_idx = ops.cast(g_idx, "int32")
250+
g_idx = ops.cast(g_idx, "float32")
251251

252252
# Map group indices and quantized weights back to original column order
253253
if activation_order:

keras/src/quantizers/quantizers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -878,8 +878,9 @@ def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq):
878878
A tensor with the same shape as `weights_matrix` containing the
879879
quantized weights produced using the provided group parameters.
880880
"""
881-
scale_cols = ops.take(scale, g_idx, axis=1) # [out_features, in_features]
882-
zero_cols = ops.take(zero, g_idx, axis=1) # [out_features, in_features]
881+
groups = ops.cast(g_idx, "int32")
882+
scale_cols = ops.take(scale, groups, axis=1) # [out_features, in_features]
883+
zero_cols = ops.take(zero, groups, axis=1) # [out_features, in_features]
883884

884885
# Quantize elementwise, then cast to int
885886
return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)
@@ -907,8 +908,9 @@ def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
907908
dequantized weights produced using the provided group parameters.
908909
"""
909910
# Map group indices to scales and zeros
910-
scales_mapped = ops.take(scale, g_idx, axis=1)
911-
zeros_mapped = ops.take(zero, g_idx, axis=1)
911+
groups = ops.cast(g_idx, "int32")
912+
scales_mapped = ops.take(scale, groups, axis=1)
913+
zeros_mapped = ops.take(zero, groups, axis=1)
912914
zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)
913915

914916
quantized = ops.multiply(

0 commit comments

Comments (0)