Skip to content

Commit 7911194

Browse files
committed
Disable kernel cutlass_mla_decode on SM103
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
1 parent 93088b6 commit 7911194

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626

2727
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
2828
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
29+
#include "utils.h"
2930

3031
// clang-format off
3132
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
@@ -217,6 +218,10 @@ void cutlass_mla_decode(
217218
torch::Tensor const& workspace,
218219
double sm_scale,
219220
int64_t num_kv_splits) {
221+
auto sm_version = getSMVersion();
222+
// On SM103a, half of the accuracy tests are failing.
223+
TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found sm version ", sm_version);
224+
220225
auto in_dtype = q_nope.dtype();
221226
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
222227
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());

sgl-kernel/tests/test_cutlass_mla.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
55
from torch import Tensor
66

7-
if torch.cuda.get_device_capability() < (10, 0):
7+
# Disable tests on SM103 until the accuracy issues are fixed.
8+
if torch.cuda.get_device_capability() != (10, 0):
89
pytest.skip(
9-
reason="Cutlass MLA Requires compute capability of 10 or above.",
10+
reason="Cutlass MLA Requires compute capability of 10.",
1011
allow_module_level=True,
1112
)
1213

0 commit comments

Comments
 (0)