12 changes: 12 additions & 0 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -305,6 +305,7 @@ def __init__(
speculative_step_id=0,
topk=0,
speculative_num_steps=0,
fa_impl_ver=3,
):
super().__init__()

@@ -338,6 +339,8 @@ def __init__(
)
self.speculative_step_id = speculative_step_id

self.fa_impl_ver = fa_impl_ver

# Local attention settings
self.attention_chunk_size = (
model_runner.attention_chunk_size
@@ -712,6 +715,8 @@ def forward_extend(

# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks

@@ -738,6 +743,7 @@

# Use Flash Attention for prefill
if not self.use_mla:
assert self.fa_impl_ver in [3], "Only FA3 is supported here"
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
@@ -830,6 +836,7 @@ def forward_extend(
softmax_scale=layer.scaling,
causal=False,
return_softmax_lse=True,
**kwargs,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
@@ -844,13 +851,15 @@
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output
else:
assert self.fa_impl_ver in [3], "Only FA3 is supported here"
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
layer.layer_id
@@ -939,6 +948,7 @@ def forward_decode(
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fa_impl_ver in [3], "Only FA3 supports decoding"
if k is not None:
assert v is not None
if save_kv_cache:
@@ -985,6 +995,8 @@

# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs = {}
if self.fa_impl_ver != 3:
kwargs["ver"] = self.fa_impl_ver
if sinks is not None:
kwargs["sinks"] = sinks

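A note on the pattern used in both forward_extend and forward_decode above: the implementation version is forwarded only when it differs from the FA3 default, so existing FA3 call sites keep their exact argument list. A minimal standalone sketch of that gating (the helper name _build_fa_kwargs is illustrative, not part of this change):

    def _build_fa_kwargs(fa_impl_ver: int, sinks=None) -> dict:
        # Pass `ver` only for non-default implementations (e.g. FA4),
        # leaving the FA3 call signature unchanged.
        kwargs = {}
        if fa_impl_ver != 3:
            kwargs["ver"] = fa_impl_ver
        if sinks is not None:
            kwargs["sinks"] = sinks
        return kwargs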
10 changes: 10 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -482,6 +482,7 @@ def model_specific_adjustment(self):
"aiter",
"flashinfer",
"fa3",
"fa4",
"triton",
"flashmla",
"cutlass_mla",
@@ -1571,6 +1572,15 @@ def _get_attention_backend_from_str(self, backend_str: str):
)

return FlashAttentionBackend(self)
elif backend_str == "fa4":
assert (
self.use_mla_backend
), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)

return FlashAttentionBackend(self, fa_impl_ver=4)
elif backend_str == "cutlass_mla":
from sglang.srt.layers.attention.cutlass_mla_backend import (
CutlassMLABackend,
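The fa4 branch above mirrors the existing fa3 case but constructs the backend with fa_impl_ver=4 and requires an MLA model. A minimal sketch of what the dispatch ends up doing (assumes model_runner is an already-initialized ModelRunner for an MLA model):

    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

    # Same backend class as fa3, with the implementation version bumped to 4.
    backend = FlashAttentionBackend(model_runner, fa_impl_ver=4)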
3 changes: 3 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -1099,6 +1099,9 @@ def _dispatch_mla_subtype():
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
elif attention_backend == "fa4":
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
return AttnForwardMethod.MHA_CHUNKED_KV
elif attention_backend == "trtllm_mla":
if (
forward_batch.forward_mode.is_extend()
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
@@ -90,6 +90,7 @@
# NVIDIA specific
"cutlass_mla",
"fa3",
"fa4",
"flashinfer",
"flashmla",
"trtllm_mla",
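With "fa4" registered alongside the other NVIDIA-specific backend names, it should be selectable the same way as "fa3". A sketch under the assumption that ServerArgs exposes the backend choice as an attention_backend field (mirroring the --attention-backend CLI flag that validates against the list above):

    from sglang.srt.server_args import ServerArgs

    # Hypothetical usage: pick the early-stage FA4 path for an MLA model such as DeepSeek-V3.
    args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V3",
        attention_backend="fa4",
    )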