12 changes: 12 additions & 0 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -305,6 +305,7 @@ def __init__(
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
fa_impl_ver=3,
):
super().__init__()

@@ -338,6 +339,8 @@ def __init__(
)
self.speculative_step_id = speculative_step_id

self.fa_impl_ver = fa_impl_ver

# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
@@ -712,6 +715,8 @@ def forward_extend(

# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks

@@ -738,6 +743,7 @@

# Use Flash Attention for prefill
if not self.use_mla:
assert self.fa_impl_ver in [3], "Only FA3 support here"
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
@@ -830,6 +836,7 @@ def forward_extend(
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
**kwargs,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
@@ -844,13 +851,15 @@
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output
else:
assert self.fa_impl_ver in [3], "Only FA3 support here"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
@@ -939,6 +948,7 @@ def forward_decode(
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fa_impl_ver in [3], "Only FA3 support decoding"
if k is not None:
assert v is not None
if save_kv_cache:
Expand Down Expand Up @@ -985,6 +995,8 @@ def forward_decode(

# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks

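The changes above thread `fa_impl_ver` through the backend and only add a `ver` keyword when the version is not 3, so every existing FA3 call site keeps its exact signature. Below is a minimal sketch of that conditional-kwargs pattern; `attn_varlen` and `run_attention` are illustrative stand-ins, not the real flash-attn or backend APIs.

```python
from typing import Optional


def attn_varlen(q, k, v, softmax_scale: float, causal: bool = True,
                sinks: Optional[object] = None, ver: int = 3):
    # Stand-in for the varlen attention entry point: it only has to accept
    # `ver` when a non-default implementation version is requested.
    return f"attention(ver={ver}, causal={causal}, sinks={sinks is not None})"


def run_attention(q, k, v, scaling: float, fa_impl_ver: int = 3, sinks=None):
    # New fields go into conditional keyword args so the FA3 path is untouched.
    kwargs = {}
    if fa_impl_ver != 3:
        kwargs["ver"] = fa_impl_ver
    if sinks is not None:
        kwargs["sinks"] = sinks
    return attn_varlen(q, k, v, softmax_scale=scaling, causal=True, **kwargs)


print(run_attention(None, None, None, scaling=0.125))                 # FA3: no `ver` passed
print(run_attention(None, None, None, scaling=0.125, fa_impl_ver=4))  # FA4: ver=4 forwarded
```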
10 changes: 10 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -471,6 +471,7 @@ def model_specific_adjustment(self):
"aiter",
"flashinfer",
"fa3",
"fa4",
"triton",
"flashmla",
"cutlass_mla",
@@ -1548,6 +1549,15 @@ def _get_attention_backend_from_str(self, backend_str: str):
)

return FlashAttentionBackend(self)
elif backend_str == "fa4":
assert (
self.use_mla_backend
), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)

return FlashAttentionBackend(self, fa_impl_ver=4)
elif backend_str == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
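The `fa4` branch reuses the existing `FlashAttentionBackend` and changes only the constructor argument, gated on MLA models for now. A simplified dispatch sketch follows; the `use_mla_backend` flag and the trimmed-down backend class are stand-ins for the real `ModelRunner` state.

```python
class FlashAttentionBackendSketch:
    # Trimmed-down stand-in: fa_impl_ver selects FA3 (default) vs FA4 behavior.
    def __init__(self, runner, fa_impl_ver: int = 3):
        self.runner = runner
        self.fa_impl_ver = fa_impl_ver


def get_attention_backend(backend_str: str, runner, use_mla_backend: bool):
    if backend_str == "fa3":
        return FlashAttentionBackendSketch(runner)
    elif backend_str == "fa4":
        # FA4 support is early-stage: only MLA models are accepted for now.
        assert use_mla_backend, "fa4 currently requires an MLA model"
        return FlashAttentionBackendSketch(runner, fa_impl_ver=4)
    raise ValueError(f"unsupported attention backend: {backend_str}")


backend = get_attention_backend("fa4", runner=None, use_mla_backend=True)
print(backend.fa_impl_ver)  # 4
```

Together with the `server_args.py` change below, this path becomes selectable at launch time via `--attention-backend fa4`.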
8 changes: 7 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -1038,6 +1038,9 @@ def _dispatch_mla_subtype():
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
elif attention_backend == "fa4":
# TODO: FA4 support is at an early stage, this is a hacky way to support MLA models.
return AttnForwardMethod.MHA_CHUNKED_KV
elif attention_backend == "aiter":
if (
forward_batch.forward_mode.is_extend()
@@ -1732,7 +1735,10 @@ def forward_normal_chunked_kv_prepare(
)

def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
has_extend_prefix = forward_batch.extend_prefix_lens_cpu and any(
forward_batch.extend_prefix_lens_cpu
)

# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
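The one-line change to `has_extend_prefix` guards against `extend_prefix_lens_cpu` being `None` or empty: `any(None)` raises `TypeError`, while the new expression short-circuits on a falsy value first. A standalone sketch of the difference, using plain lists instead of the real `forward_batch`:

```python
def has_extend_prefix_old(extend_prefix_lens_cpu):
    # Old form: raises TypeError when the value is None.
    return any(extend_prefix_lens_cpu)


def has_extend_prefix_new(extend_prefix_lens_cpu):
    # New form: short-circuits on None/empty before calling any();
    # bool() just normalizes the result for display.
    return bool(extend_prefix_lens_cpu and any(extend_prefix_lens_cpu))


print(has_extend_prefix_new(None))       # False
print(has_extend_prefix_new([0, 0]))     # False: no prefix tokens to attend to
print(has_extend_prefix_new([0, 5]))     # True: chunked prefix cache path is taken
try:
    has_extend_prefix_old(None)
except TypeError as exc:
    print("old form:", exc)              # 'NoneType' object is not iterable
```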
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
@@ -89,6 +89,7 @@
# NVIDIA specific
"cutlass_mla",
"fa3",
"fa4",
"flashinfer",
"flashmla",
"trtllm_mla",
19 changes: 19 additions & 0 deletions sgl-kernel/CMakeLists.txt
@@ -93,6 +93,15 @@ FetchContent_Declare(
)
FetchContent_Populate(repo-flash-attention)

# flash-attention origin
FetchContent_Declare(
repo-flash-attention-origin
GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention-origin)

# mscclpp
FetchContent_Declare(
repo-mscclpp
@@ -499,3 +508,13 @@ install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernel
DESTINATION "triton_kernels"
PATTERN ".git*" EXCLUDE
PATTERN "__pycache__" EXCLUDE)

# flash attention 4
# TODO: find a better install condition.
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
# flash_attn/cute
install(DIRECTORY "${repo-flash-attention-origin_SOURCE_DIR}/flash_attn/cute/"
DESTINATION "flash_attn/cute"
PATTERN ".git*" EXCLUDE
PATTERN "__pycache__" EXCLUDE)
endif()
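The install block above ships the upstream `flash_attn/cute` sources only when CUDA >= 12.8 or `SGL_KERNEL_ENABLE_SM100A` is set. Assuming the wheel lays that directory out as an importable `flash_attn.cute` package (the same way the `triton_kernels` install above is consumed), a quick standard-library check for whether the FA4 kernels made it into a given build might look like this:

```python
import importlib.util


def fa4_kernels_installed() -> bool:
    # True when the flash_attn/cute subpackage installed by the block above is
    # importable. find_spec() imports the parent package first, so guard
    # against flash_attn being absent entirely.
    try:
        return importlib.util.find_spec("flash_attn.cute") is not None
    except ModuleNotFoundError:
        return False


print("flash_attn.cute available:", fa4_kernels_installed())
```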
5 changes: 5 additions & 0 deletions sgl-kernel/pyproject.toml
@@ -20,6 +20,11 @@ classifiers = [
]
dependencies = []

[project.optional-dependencies]
blackwell = [
"nvidia-cutlass-dsl==4.1.0",
]

[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"