Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
#include "utils.h"

// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
Expand Down Expand Up @@ -217,6 +218,10 @@ void cutlass_mla_decode(
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
auto sm_version = getSMVersion();
// Restrict to SM100 for now: on SM103a, half of the accuracy tests are failing.
TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found sm version ", sm_version);

auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
Expand Down
5 changes: 3 additions & 2 deletions sgl-kernel/tests/test_cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor

if torch.cuda.get_device_capability() < (10, 0):
# Only run on SM100 (compute capability 10.0); SM103 stays disabled until the accuracy issues are fixed.
if torch.cuda.get_device_capability() != (10, 0):
pytest.skip(
reason="Cutlass MLA Requires compute capability of 10 or above.",
reason="Cutlass MLA Requires compute capability of 10.",
allow_module_level=True,
)

Expand Down
Loading