We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 88a689d · commit 81bb245 (Copy full SHA for 81bb245)
vllm_ascend/models/layers/mla.py
@@ -132,8 +132,13 @@ def forward(
132
output = torch.empty(output_shape,
133
dtype=hidden_states.dtype,
134
device=hidden_states.device)
135
+ if forward_context.attn_metadata:
136
+ attn_metadata = forward_context.attn_metadata[
137
+ self.mla_attn.layer_name]
138
+ else:
139
+ attn_metadata = forward_context.attn_metadata
140
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
- forward_context.attn_metadata,
- need_gather_q_kv, output)
141
+ attn_metadata, need_gather_q_kv,
142
+ output)
143
output = output.view(-1, output_shape[-1])
144
return output
0 commit comments