Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/moe/cutlass_w4a8_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
k,
)

c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.empty((m * topk, k), device=device, dtype=torch.bfloat16)

cutlass_w4a8_moe_mm(
c1,
Expand All @@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
topk,
)

intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
silu_and_mul(c1, intermediate)

intermediate_q = torch.empty(
Expand Down Expand Up @@ -197,7 +197,7 @@ def cutlass_w4a8_moe(
src2dst,
local_topk_ids,
topk_weights,
num_experts,
total_num_experts,
topk,
k,
0,
Expand Down
11 changes: 5 additions & 6 deletions sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)

# Create output tensor
c = torch.empty((m, n), dtype=torch.float16, device=device)
c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
cutlass_w4a8_moe_mm(
c,
a_q,
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
b_strides = a_strides
s_strides = c_strides

c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
cutlass_w4a8_moe_mm(
c_perm,
a_q_perm,
Expand Down Expand Up @@ -262,10 +262,9 @@ def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_r
continue
a = a_q[token_idx]

ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
c = c.to(dtype)
ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(torch.float32)
ref_w = w[i].to(torch.float32) * ref_w_scale_repeat
c = torch.matmul(a.to(torch.float32), ref_w.t()) * a_scale
c_ref[token_idx] = c.to(dtype)

return c_ref
Expand Down