
Commit 73a8e05

add vllm patch
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent: 5241893

File tree: 5 files changed, +727 -0 lines changed

vllm_ascend/patch/platform/patch_common/__init__.py
Lines changed: 2 additions & 0 deletions
@@ -15,5 +15,7 @@
 # limitations under the License.
 #
 
+import vllm_ascend.patch.platform.patch_common.patch_block_table  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_common.patch_mamba_config  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe  # noqa
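
These imports are what activate the patches: each patch module, when imported, rebinds names on the corresponding upstream vllm module. A minimal sketch of that import-time pattern, with PatchedBlockTable as a hypothetical stand-in for the real replacement class:

import vllm.v1.worker.block_table


class PatchedBlockTable(vllm.v1.worker.block_table.BlockTable):
    """Hypothetical drop-in replacement for the upstream class."""


# Rebinding the module attribute is the entire patch: any code that
# resolves vllm.v1.worker.block_table.BlockTable after this module has
# been imported receives the replacement class instead.
vllm.v1.worker.block_table.BlockTable = PatchedBlockTable

The patch_block_table module added below follows exactly this pattern, ending with two such rebinding assignments.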
vllm_ascend/patch/platform/patch_common/patch_block_table.py
Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
from typing import Optional, Union

import numpy as np
import torch
import vllm.v1.worker.block_table
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv


class BlockTable:

    def __init__(self,
                 block_size: int,
                 max_num_reqs: int,
                 max_num_blocks_per_req: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 kernel_sizes: Union[list[int], None] = None):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device
        self.physical_block_size = block_size
        # If kernel_sizes is None or [0], use the physical block size
        # (no splitting).
        if kernel_sizes is None or kernel_sizes == [0]:
            self.block_size = block_size
            self.logical_block_size = block_size
            self.blocks_per_phys_block = 1
            self.use_hybrid_blocks = False
        else:
            # Find the first kernel size that divides physical_block_size
            # evenly.
            selected_kernel_size = None
            for kernel_size in kernel_sizes:
                if kernel_size > 0 \
                        and self.physical_block_size % kernel_size == 0:
                    selected_kernel_size = kernel_size
                    break

            if selected_kernel_size is None:
                raise ValueError(
                    f"None of the kernel sizes {kernel_sizes} can divide "
                    f"physical block size {self.physical_block_size} evenly")

            self.block_size = selected_kernel_size
            self.logical_block_size = selected_kernel_size
            self.blocks_per_phys_block = (self.physical_block_size //
                                          self.logical_block_size)
            if self.blocks_per_phys_block > 1:
                self.use_hybrid_blocks = True
            else:
                self.use_hybrid_blocks = False

        if self.use_hybrid_blocks:
            logical_table_size = (max_num_blocks_per_req *
                                  self.blocks_per_phys_block)
        else:
            logical_table_size = max_num_blocks_per_req

        self.block_table = torch.zeros(
            (max_num_reqs, logical_table_size),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, logical_table_size),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        if not block_ids:
            return

        if self.use_hybrid_blocks:
            block_ids = self._convert_physical_to_logical_blocks(
                np.array(block_ids))

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]

        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] += num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.

        if self.dcp_world_size > 1:
            # Note(hc): The DCP implementation stores the kvcache in an
            # interleaved style: the kvcache for the token whose token_idx
            # is i is always stored on the GPU whose dcp_rank equals
            # i % dcp_world_size.

            # Use a "virtual block" whose size equals
            # dcp_world_size * block_size for the block_table_indices
            # calculation.
            virtual_block_size = self.block_size * self.dcp_world_size

            # IMPORTANT: In hybrid mode, positions are in logical block
            # space, but we need to map them to the correct logical block
            # table indices.
            logical_block_idx = positions // virtual_block_size

            # Account for the expanded logical table (always needed with
            # the unified tensor): each physical block is split into
            # multiple logical blocks, and the logical table has been
            # expanded to accommodate this.
            block_table_indices = (req_indices * self.max_num_blocks_per_req *
                                   self.blocks_per_phys_block +
                                   logical_block_idx)

            block_numbers = self.block_table_np.ravel()[block_table_indices]
            # Use virtual_block_size for the mask calculation, which marks
            # local tokens.
            virtual_block_offsets = positions % virtual_block_size
            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
            # Calculate local block_offsets.
            block_offsets = virtual_block_offsets // self.dcp_world_size
            # Calculate slot_mapping.
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write final slots, using -1 for non-local tokens.
            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
                mask, slot_mapping, -1)
        else:
            # IMPORTANT: In hybrid mode, positions are in logical block
            # space, but we need to map them to the correct logical block
            # table indices.
            logical_block_idx = positions // self.block_size

            # Account for the expanded logical table (always needed with
            # the unified tensor): each physical block is split into
            # multiple logical blocks, and the logical table has been
            # expanded to accommodate this.
            block_table_indices = (req_indices * self.max_num_blocks_per_req *
                                   self.blocks_per_phys_block +
                                   logical_block_idx)

            block_numbers = self.block_table_np.ravel()[block_table_indices]
            block_offsets = positions % self.block_size
            np.add(block_numbers * self.block_size,
                   block_offsets,
                   out=self.slot_mapping_np[:req_indices.shape[0]])

    def commit_block_table(self, num_reqs: int) -> None:
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        self.slot_mapping[:num_tokens].copy_(
            self.slot_mapping_cpu[:num_tokens], non_blocking=True)

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def _convert_physical_to_logical_blocks(
            self, physical_blocks: np.ndarray) -> np.ndarray:
        """Convert physical block IDs to logical block IDs."""
        if not self.use_hybrid_blocks:
            return physical_blocks

        # Create logical block IDs by splitting each physical block:
        # physical block 1 becomes logical blocks
        # [1 * split_ratio, 1 * split_ratio + 1, ...].
        # But we need to account for the fact that block 0 is special.
        logical_blocks: list[int] = []
        for phys_block in physical_blocks:
            base_logical = phys_block * self.blocks_per_phys_block
            logical_blocks.extend(
                range(base_logical, base_logical + self.blocks_per_phys_block))

        return np.array(logical_blocks, dtype=np.int32)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self,
                 max_num_reqs: int,
                 max_model_len: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 block_sizes: list[int],
                 num_speculative_tokens: int = 0,
                 kernel_sizes: Optional[list[list[int]]] = None) -> None:
        # Note(hc): each dcp rank only stores
        # (max_model_len // dcp_world_size) tokens in its kvcache, so the
        # block_size used to calculate max_num_blocks_per_req must be
        # multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing.
            dcp_world_size = 1

        if kernel_sizes is None:
            kernel_sizes = [[0]] * len(block_sizes)
        # Ensure kernel_sizes matches block_sizes in length.
        elif len(kernel_sizes) == 1 and len(block_sizes) > 1:
            kernel_sizes = kernel_sizes * len(block_sizes)
        elif len(kernel_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_sizes length ({len(kernel_sizes)}) must match "
                f"block_sizes length ({len(block_sizes)})")

        # Use zip to pair block_sizes with kernel_sizes one-to-one.
        self.block_tables = [
            BlockTable(
                block_size, max_num_reqs,
                max(cdiv(max_model_len, block_size * dcp_world_size),
                    1 + num_speculative_tokens), max_num_batched_tokens,
                pin_memory, device, kernel_size_list)
            for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
        ]

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the idx-th KV cache group."""
        return self.block_tables[idx]


vllm.v1.worker.block_table.BlockTable = BlockTable
vllm.v1.worker.block_table.MultiGroupBlockTable = MultiGroupBlockTable
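
To make the hybrid-block arithmetic concrete, here is a small self-contained sketch (numpy only; the sizes are illustrative, and it reproduces the math of _convert_physical_to_logical_blocks and the non-DCP branch of compute_slot_mapping rather than invoking the patched class):

import numpy as np

# Illustrative sizes: a 128-token physical block split by a kernel size
# of 32 yields 4 logical blocks per physical block.
physical_block_size = 128
logical_block_size = 32
blocks_per_phys_block = physical_block_size // logical_block_size  # 4

# Physical block 5 expands to logical blocks [20, 21, 22, 23], matching
# what _convert_physical_to_logical_blocks computes.
phys_block = 5
base = phys_block * blocks_per_phys_block
logical_blocks = np.arange(base, base + blocks_per_phys_block)
print(logical_blocks)  # [20 21 22 23]

# Slot mapping for token position 70 of this request (non-DCP branch):
# logical block index 70 // 32 = 2 selects logical block 22, and the
# offset 70 % 32 = 6 gives slot 22 * 32 + 6 = 710.
position = 70
block_number = logical_blocks[position // logical_block_size]
slot = block_number * logical_block_size + position % logical_block_size
print(slot)  # 710

Under DCP the same indexing runs over a "virtual block" of dcp_world_size * block_size tokens, and positions owned by other ranks are written to the slot mapping as -1.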
