Skip to content

Commit 384f36c

Browse files
committed
fix kv cache block shape
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent b25d549 commit 384f36c

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
115115
mock_nd_to_nz_2d.return_value = mock_nz_tensor
116116
mock_npu_format_cast.return_value = mock_nz_tensor
117117

118-
self.builder.build(common_attn_metadata, mock_model)
118+
self.builder.build(1, common_attn_metadata, mock_model)
119119

120120
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
121121
@patch('torch_npu.npu_format_cast')
@@ -151,7 +151,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
151151
mock_nd_to_nz_spec.return_value = mock_nz_tensor
152152
mock_npu_format_cast.return_value = mock_nz_tensor
153153

154-
self.builder.build(common_attn_metadata, mock_model)
154+
self.builder.build(1, common_attn_metadata, mock_model)
155155

156156
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
157157
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@@ -175,7 +175,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
175175
seq_lens=None)
176176
mock_model = MagicMock()
177177

178-
self.builder.build(common_attn_metadata, mock_model)
178+
self.builder.build(1, common_attn_metadata, mock_model)
179179

180180

181181
class TestAscendAttentionBackendImpl(TestBase):

tests/ut/torchair/test_torchair_mla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def test_build_decode(self, mock_ascend_config):
456456
num_computed_tokens_cpu=None,
457457
seq_lens=None)
458458

459-
metadata = builder.build(common_attn_metadata, model)
459+
metadata = builder.build(1, common_attn_metadata, model)
460460

461461
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
462462
self.assertEqual(metadata.num_input_tokens, 0)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,16 +2459,17 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
24592459
kv_cache_spec.num_kv_heads,
24602460
kv_cache_spec.head_size)
24612461
elif hasattr(attn_backend, "get_supported_block_size"):
2462-
# kv_cache_shape = attn_backend.get_kv_cache_shape(
2463-
# num_blocks, kv_cache_spec.block_size,
2464-
# kv_cache_spec.num_kv_heads,
2465-
# kv_cache_spec.head_size)
24662462
block_size = attn_backend.get_supported_block_size()[0]
24672463
block_size_chunk = kv_cache_spec.block_size // block_size
24682464
kv_cache_shape = attn_backend.get_kv_cache_shape(
24692465
num_blocks * block_size_chunk, block_size,
24702466
kv_cache_spec.num_kv_heads,
24712467
kv_cache_spec.head_size)
2468+
else:
2469+
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
2470+
num_blocks, kv_cache_spec.block_size,
2471+
kv_cache_spec.num_kv_heads,
2472+
kv_cache_spec.head_size)
24722473
dtype = kv_cache_spec.dtype
24732474
if self.model_config.is_deepseek_mla:
24742475
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape

0 commit comments

Comments
 (0)