Skip to content

Commit 81bb245

Browse files
committed
fix deepseek mla
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 88a689d commit 81bb245

File tree

1 file changed

+7
-2
lines changed
  • vllm_ascend/models/layers

1 file changed

+7
-2
lines changed

vllm_ascend/models/layers/mla.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,13 @@ def forward(
132132
output = torch.empty(output_shape,
133133
dtype=hidden_states.dtype,
134134
device=hidden_states.device)
135+
if forward_context.attn_metadata:
136+
attn_metadata = forward_context.attn_metadata[
137+
self.mla_attn.layer_name]
138+
else:
139+
attn_metadata = forward_context.attn_metadata
135140
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
136-
forward_context.attn_metadata,
137-
need_gather_q_kv, output)
141+
attn_metadata, need_gather_q_kv,
142+
output)
138143
output = output.view(-1, output_shape[-1])
139144
return output

0 commit comments

Comments
 (0)