from typing import Optional

import numpy as np
import torch
import vllm.v1.worker.block_table
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv


class BlockTable:

    def __init__(self,
                 block_size: int,
                 max_num_reqs: int,
                 max_num_blocks_per_req: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 kernel_sizes: Optional[list[int]] = None):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device
        self.physical_block_size = block_size
        # If kernel_sizes is None or [0], use the physical block size
        # as-is (no splitting).
        if kernel_sizes is None or kernel_sizes == [0]:
            self.block_size = block_size
            self.logical_block_size = block_size
            self.blocks_per_phys_block = 1
            self.use_hybrid_blocks = False
        else:
            # Find the first kernel size that divides the physical block
            # size evenly.
            selected_kernel_size = None
            for kernel_size in kernel_sizes:
                if (kernel_size > 0
                        and self.physical_block_size % kernel_size == 0):
                    selected_kernel_size = kernel_size
                    break

            if selected_kernel_size is None:
                raise ValueError(
                    f"None of the kernel sizes {kernel_sizes} can divide "
                    f"physical block size {self.physical_block_size} evenly")

            self.block_size = selected_kernel_size
            self.logical_block_size = selected_kernel_size
            self.blocks_per_phys_block = (self.physical_block_size //
                                          self.logical_block_size)
            self.use_hybrid_blocks = self.blocks_per_phys_block > 1
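
        # Example (illustrative): with block_size=64 and
        # kernel_sizes=[48, 16], 48 does not divide 64, so 16 is
        # selected; each physical block of 64 slots is then viewed as
        # 4 logical blocks of 16 slots, and use_hybrid_blocks is True.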

        if self.use_hybrid_blocks:
            logical_table_size = (max_num_blocks_per_req *
                                  self.blocks_per_phys_block)
        else:
            logical_table_size = max_num_blocks_per_req

        self.block_table = torch.zeros(
            (max_num_reqs, logical_table_size),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, logical_table_size),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
                                            dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                        dtype=torch.int64,
                                        device=self.device)
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing.
            self.dcp_world_size = 1
            self.dcp_rank = 0

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        if not block_ids:
            return

        if self.use_hybrid_blocks:
            block_ids = self._convert_physical_to_logical_blocks(
                np.array(block_ids))

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]

        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] += num_blocks
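
    # Example (illustrative): with blocks_per_phys_block=2,
    # append_row([3, 7], row_idx=0) writes the logical IDs
    # [6, 7, 14, 15] into row 0 and advances num_blocks_per_row[0] by 4.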

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.

        if self.dcp_world_size > 1:
            # Note(hc): DCP stores the KV cache in an interleaved style:
            # the KV cache of the token whose token_idx is i is always
            # stored on the GPU whose dcp_rank equals
            # i % dcp_world_size.

            # Use a "virtual block" of dcp_world_size * block_size tokens
            # for the block_table_indices calculation.
            virtual_block_size = self.block_size * self.dcp_world_size

            # Positions are token positions; dividing by the virtual
            # block size yields logical block indices.
            logical_block_idx = positions // virtual_block_size

            # The row stride of the flattened logical table is
            # max_num_blocks_per_req * blocks_per_phys_block, since each
            # physical block is split into blocks_per_phys_block logical
            # blocks (the factor is 1 when hybrid blocks are disabled).
            block_table_indices = (req_indices * self.max_num_blocks_per_req *
                                   self.blocks_per_phys_block +
                                   logical_block_idx)

            block_numbers = self.block_table_np.ravel()[block_table_indices]
            # Use virtual_block_size for the mask calculation, which
            # marks local tokens.
            virtual_block_offsets = positions % virtual_block_size
            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
            # Calculate local block offsets.
            block_offsets = virtual_block_offsets // self.dcp_world_size
            # Calculate the slot mapping.
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write the final slots, using -1 for non-local tokens.
            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
                mask, slot_mapping, -1)
        else:
            # Positions are token positions; dividing by the logical
            # block size yields logical block indices.
            logical_block_idx = positions // self.block_size

            # The row stride of the flattened logical table is
            # max_num_blocks_per_req * blocks_per_phys_block, since each
            # physical block is split into blocks_per_phys_block logical
            # blocks (the factor is 1 when hybrid blocks are disabled).
            block_table_indices = (req_indices * self.max_num_blocks_per_req *
                                   self.blocks_per_phys_block +
                                   logical_block_idx)

            block_numbers = self.block_table_np.ravel()[block_table_indices]
            block_offsets = positions % self.block_size
            np.add(block_numbers * self.block_size,
                   block_offsets,
                   out=self.slot_mapping_np[:req_indices.shape[0]])
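
    # Worked example (illustrative, non-DCP path): with block_size=2,
    # blocks_per_phys_block=1, max_num_blocks_per_req=4, and row 0 of
    # the block table holding [5, 9, ...], req_indices=[0, 0, 0, 0] and
    # positions=[0, 1, 2, 3] give logical_block_idx=[0, 0, 1, 1],
    # block_numbers=[5, 5, 9, 9], block_offsets=[0, 1, 0, 1], and hence
    # slot_mapping=[10, 11, 18, 19]. With DCP enabled, non-local tokens
    # are written as -1 instead.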

    def commit_block_table(self, num_reqs: int) -> None:
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        self.slot_mapping[:num_tokens].copy_(
            self.slot_mapping_cpu[:num_tokens], non_blocking=True)

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)

    def _convert_physical_to_logical_blocks(
            self, physical_blocks: np.ndarray) -> np.ndarray:
        """Convert physical block IDs to logical block IDs."""
        if not self.use_hybrid_blocks:
            return physical_blocks

        # Split each physical block into blocks_per_phys_block logical
        # blocks: physical block p becomes logical blocks
        # [p * blocks_per_phys_block, ...,
        #  (p + 1) * blocks_per_phys_block - 1].
        # Physical block 0 maps to logical block 0 onward, so the
        # zero-filled padding of the table stays consistent.
        logical_blocks: list[int] = []
        for phys_block in physical_blocks:
            base_logical = phys_block * self.blocks_per_phys_block
            logical_blocks.extend(
                range(base_logical, base_logical + self.blocks_per_phys_block))

        return np.array(logical_blocks, dtype=np.int32)
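
    # A vectorized equivalent of the loop above, as a sketch (assuming
    # physical_blocks is a 1-D integer array):
    #   r = self.blocks_per_phys_block
    #   (physical_blocks[:, None] * r + np.arange(r)).ravel()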

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self,
                 max_num_reqs: int,
                 max_model_len: int,
                 max_num_batched_tokens: int,
                 pin_memory: bool,
                 device: torch.device,
                 block_sizes: list[int],
                 num_speculative_tokens: int = 0,
                 kernel_sizes: Optional[list[list[int]]] = None) -> None:
        # Note(hc): each DCP rank only stores
        # max_model_len // dcp_world_size tokens in its KV cache, so the
        # block size used to compute max_num_blocks_per_req must be
        # multiplied by dcp_world_size.
        try:
            dcp_world_size = get_dcp_group().world_size
        except AssertionError:
            # DCP might not be initialized in testing.
            dcp_world_size = 1

        if kernel_sizes is None:
            kernel_sizes = [[0]] * len(block_sizes)
        # Broadcast a single kernel-size list across all block sizes.
        elif len(kernel_sizes) == 1 and len(block_sizes) > 1:
            kernel_sizes = kernel_sizes * len(block_sizes)
        elif len(kernel_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_sizes length ({len(kernel_sizes)}) must match "
                f"block_sizes length ({len(block_sizes)})")

        # Pair block_sizes with kernel_sizes one-to-one.
        self.block_tables = [
            BlockTable(
                block_size, max_num_reqs,
                max(cdiv(max_model_len, block_size * dcp_world_size),
                    1 + num_speculative_tokens), max_num_batched_tokens,
                pin_memory, device, kernel_size_list)
            for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
        ]
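
    # Example (illustrative): with max_model_len=8192, block_sizes=[16],
    # dcp_world_size=1, and num_speculative_tokens=0, each request gets
    # max(cdiv(8192, 16), 1) = 512 block-table slots per KV cache group.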

    def append_row(self, block_ids: tuple[list[int], ...],
                   row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray,
                             positions: np.ndarray) -> None:
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the idx-th KV cache group."""
        return self.block_tables[idx]


# Patch vLLM's v1 worker to use the hybrid-block implementations above.
vllm.v1.worker.block_table.BlockTable = BlockTable
vllm.v1.worker.block_table.MultiGroupBlockTable = MultiGroupBlockTable
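

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, CPU-only): a table whose
    # physical block size of 32 is split into logical blocks of 16.
    table = BlockTable(block_size=32,
                       max_num_reqs=2,
                       max_num_blocks_per_req=4,
                       max_num_batched_tokens=64,
                       pin_memory=False,
                       device=torch.device("cpu"),
                       kernel_sizes=[16])
    table.add_row([1, 2], row_idx=0)
    # Physical blocks [1, 2] expand to logical blocks [2, 3, 4, 5].
    print(table.block_table_np[0, :4])
    table.compute_slot_mapping(np.zeros(4, dtype=np.int32),
                               np.arange(4, dtype=np.int64))
    # Tokens 0-3 fall in logical block 2, giving slots [32, 33, 34, 35].
    print(table.slot_mapping_np[:4])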