Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,9 @@ def check_prefetch_progress(self, req_id: str) -> bool:

# todo: more policies for prefetch progress such as timeout
# the current policy is to prefetch with best effort and terminate when queuing is over
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
req_id
)
]

if operation.host_indices is None:
# prefetch has not been issued due to insufficient host memory
Expand Down Expand Up @@ -512,6 +512,7 @@ def check_prefetch_progress(self, req_id: str) -> bool:
host_indices[min_completed_tokens:completed_tokens]
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

return True
Expand Down Expand Up @@ -775,15 +776,14 @@ def release_aborted_request(self, rid: str):
if rid not in self.ongoing_prefetch:
return

last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch.pop(
rid
)
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
if operation.host_indices is None:
return

completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
if self.tp_world_size > 1:
torch.distributed.barrier(group=self.tp_group)
last_host_node.release_host()
del self.ongoing_prefetch[rid]
self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
Loading