
Commit 6624fec

Some GPTQ cases could not be handled by IPEX, but could be handled by Triton (#3298)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 5284b5c commit 6624fec

File tree

2 files changed: +19, -6 lines

server/text_generation_server/layers/gptq/__init__.py

Lines changed: 14 additions & 5 deletions
@@ -12,11 +12,7 @@
     WeightsLoader,
     DefaultWeightsLoader,
 )
-
-if SYSTEM == "ipex":
-    from .ipex import QuantLinear
-elif SYSTEM in {"cuda", "rocm"}:
-    from .triton import QuantLinear
+import math
 
 
 @dataclass
@@ -70,6 +66,19 @@ def get_linear(self, bias: torch.Tensor):
 
             return ExllamaQuantLinear(self, bias)
         else:
+            if SYSTEM == "ipex" and not (
+                self.device.type == "xpu"
+                and (
+                    self.bits != 4
+                    or math.ceil(
+                        (self.qweight.shape[0] * 32 // self.bits) / self.groupsize
+                    )
+                    != self.scales.shape[0]
+                )
+            ):
+                from .ipex import QuantLinear
+            else:
+                from .triton import QuantLinear
             return QuantLinear(
                 self.qweight,
                 self.qzeros,
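
The heart of the change is the dispatch condition above. Restated as a standalone predicate (a sketch for illustration only: ipex_can_handle is a hypothetical name, not part of the commit, and the shape reasoning assumes the usual GPTQ packing where each int32 row of qweight holds 32 // bits quantized values):

import math

import torch


def ipex_can_handle(
    qweight: torch.Tensor,  # packed int32 weights, (in_features * bits // 32, out_features)
    scales: torch.Tensor,  # per-group scales, (n_groups, out_features)
    bits: int,
    groupsize: int,
    device_type: str,  # e.g. "xpu" or "cpu"
) -> bool:
    # Off XPU (e.g. CPU under IPEX), the IPEX kernel is kept unconditionally.
    if device_type != "xpu":
        return True
    # On XPU, IPEX only covers 4-bit weights...
    if bits != 4:
        return False
    # Each int32 row packs 32 // bits values, so this recovers in_features.
    in_features = qweight.shape[0] * 32 // bits
    # ...whose scales tensor has exactly one row per quantization group.
    return math.ceil(in_features / groupsize) == scales.shape[0]

With the imports moved inside get_linear, .ipex.QuantLinear is chosen only when SYSTEM == "ipex" and this predicate holds; every other case falls back to .triton.QuantLinear, which is exactly the situation the commit title describes.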

server/text_generation_server/layers/gptq/triton.py

Lines changed: 5 additions & 1 deletion
@@ -202,7 +202,11 @@ def matmul_248_kernel(
 
 
 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    with torch.cuda.device(input.device):
+    with (
+        torch.xpu.device(input.device)
+        if torch.xpu.is_available()
+        else torch.cuda.device(input.device)
+    ):
         output = torch.empty(
             (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
         )
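
The same pattern, as a minimal runtime-device-selection sketch (device_ctx is a hypothetical name, not from the commit):

import torch


def device_ctx(device: torch.device):
    # torch.xpu mirrors the torch.cuda API; torch.xpu.device and
    # torch.cuda.device are equivalent context managers that make
    # `device` current for subsequent kernel launches.
    if torch.xpu.is_available():
        return torch.xpu.device(device)
    return torch.cuda.device(device)

Selecting the guard at runtime keeps matmul248 backend-agnostic: the Triton kernel it launches runs on whichever device is current inside the with block.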
