diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
index b9a879e32d..d6961c9e3a 100644
--- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
+++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
@@ -231,7 +231,7 @@ def get_reasoning_parser(self):
     def _init_logits_processor(
         self,
         schemata_key: tuple[str, str],
-        enable_thinking: bool = False,
+        **kwargs,
     ) -> LogitsProcessorBase:
         """
         init logits processor by type and schemata.
@@ -248,13 +248,13 @@ def _init_logits_processor(
         """
         key_type, schemata = schemata_key
         if key_type == "json":
-            return self._json_processor(schemata, enable_thinking)
+            return self._json_processor(schemata, **kwargs)
         elif key_type == "regex":
-            return self._regex_processor(schemata, enable_thinking)
+            return self._regex_processor(schemata, **kwargs)
         elif key_type == "grammar":
-            return self._grammar_processor(schemata, enable_thinking)
+            return self._grammar_processor(schemata, **kwargs)
         elif key_type == "structural_tag":
-            return self._structural_tag_processor(schemata, enable_thinking)
+            return self._structural_tag_processor(schemata, **kwargs)
         else:
             llm_logger.error(f"Unsupported processor type {key_type}.")
             return None
@@ -262,7 +262,7 @@ def _init_logits_processor(
     def get_logits_processor(
         self,
         schemata_key: tuple[str, str],
-        enable_thinking: bool = False,
+        **kwargs,
    ) -> tuple[LogitsProcessorBase, bool]:
         """
         get logits processor by key from cache or create new one.
@@ -278,9 +278,9 @@ def get_logits_processor(
         value = self.cache.get(schemata_key, None)
         if value:
             value_copy = value.copy()
-            value_copy.enable_reasoning = enable_thinking
+            value_copy.enable_reasoning = kwargs.get("enable_thinking", False)
             return value_copy, True
-        value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
+        value = self.executor.submit(self._init_logits_processor, schemata_key, **kwargs)
         return value, False
 
     def _get_tokenizer_hf(self):
diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
index d32d57f3c9..949b10bcd2 100644
--- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
+++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
@@ -40,8 +40,9 @@
         StructuralTagItem,
         TokenizerInfo,
         allocate_token_bitmask,
-        apply_token_bitmask_inplace,
     )
+
+    from .kernels.xgrammar_apply_token_bitmask import apply_token_bitmask_inplace_triton
 except Exception as e:
     raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")
@@ -71,6 +72,7 @@ def __init__(
         vocab_size: Optional[int] = None,
         batch_size: Optional[int] = None,
         enable_thinking: bool = False,
+        request_id: Optional[str] = None,
     ):
         super().__init__(enable_reasoning=enable_thinking)
         self.max_rollback_tokens = 200
@@ -79,6 +81,7 @@
         self.compiled_grammar = compiled_grammar
         self.terminate_without_stop_token = terminate_without_stop_token
         self.override_stop_tokens = override_stop_tokens
+        self.request_id = request_id
 
         self.matcher = GrammarMatcher(
             compiled_grammar=compiled_grammar,
@@ -112,7 +115,7 @@ def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
     def apply_token_mask(
         self,
         logits: paddle.Tensor,
-        token_bitmask: torch.Tensor,
+        token_bitmask: paddle.Tensor,
         indices: Optional[List[int]] = None,
     ) -> paddle.Tensor:
         """
@@ -120,28 +123,21 @@ def apply_token_mask(
         Args:
             logits (paddle.Tensor): The logits tensor to modify
-            token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
+            token_bitmask (paddle.Tensor): The token bitmask indicating allowed tokens
             indices (Optional[List[int]]): Optional list of batch indices to apply mask to
 
         Returns:
             paddle.Tensor: The modified logits tensor
         """
-        origin_place = logits.place
-        origin_dtype = logits.dtype
-        logits = torch.from_numpy(logits.numpy())
-
-        logits = logits.float()  # cpu
-        apply_token_bitmask_inplace(
-            logits=logits,
-            bitmask=token_bitmask.to(logits.device, non_blocking=True),
-            indices=indices,
-        )
+        if token_bitmask.place != logits.place:
+            token_bitmask = token_bitmask.to(device=logits.place)
 
-        return paddle.to_tensor(
-            logits.numpy(),
-            dtype=origin_dtype,
-            place=origin_place,
-        )
+        if logits.place.is_gpu_place():
+            apply_token_bitmask_inplace_triton(logits, token_bitmask, self.vocab_size, indices)
+        else:
+            llm_logger.error(f"Unsupported device {logits.place}, skip guided decoding.")
+
+        return logits
 
     def reset(self) -> None:
         """
@@ -156,13 +152,16 @@ def accept_token(self, token: int) -> None:
         """
         Validate and accept a generated token against the grammar constraints.
 
+        # When the output reaches the maximum length, an eos_token is forced;
+        # that token is not restricted by guided decoding, so a failed accept is logged instead of raising
+
         Args:
             token (int): The token ID to validate
-
-        Raises:
-            AssertionError: If token is not allowed by the grammar
         """
-        assert self.matcher.accept_token(token), f"Failed to accept token {token}"
+        if not self.matcher.accept_token(token):
+            llm_logger.error(f"request: {self.request_id} failed to accept token [{token}]")
+            return False
+        return True
 
     def is_terminated(self) -> bool:
         """
@@ -226,6 +225,7 @@ def _create_processor(
         terminate_without_stop_token: bool = False,
         override_stop_tokens: Optional[List[int]] = None,
         enable_thinking: bool = False,
+        request_id: Optional[str] = None,
     ) -> XGrammarProcessor:
         """
         Create a logits processor instance for the given compiled grammar.
@@ -246,9 +246,10 @@
             vocab_size=self.vocab_size,
             batch_size=self.batch_size,
             enable_thinking=enable_thinking,
+            request_id=request_id,
         )
 
-    def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
+    def _json_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
         """
         Compile JSON schema into a grammar processor.
@@ -264,9 +265,9 @@ def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optio
         except Exception as e:
             llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
+        return self._create_processor(compiled_grammar, **kwargs)
 
-    def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
+    def _regex_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
         """
         Compile regex pattern into a grammar processor.
@@ -282,9 +283,9 @@ def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Opti
         except Exception as e:
             llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
+        return self._create_processor(compiled_grammar, **kwargs)
 
-    def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
+    def _grammar_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
         """
         Compile grammar (EBNF) into a grammar processor.
@@ -300,9 +301,9 @@ def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Op
         except Exception as e:
             llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
+        return self._create_processor(compiled_grammar, **kwargs)
 
-    def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
+    def _structural_tag_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
         """
         Compile structural tags into a grammar processor.
@@ -327,7 +328,7 @@ def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False
         except Exception as e:
             llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
+        return self._create_processor(compiled_grammar, **kwargs)
 
 
 class XGrammarChecker(BaseChecker):
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index f8fd1755ab..72b3dd3933 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -15,9 +15,11 @@
 """
 
 import threading
+from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional
 
+import numpy as np
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
@@ -52,34 +54,51 @@ def top_p_normalize_probs_paddle(
 class SamplerProcessor:
-    """
-    SamplingProcessor for guided decoding.
+    """Handles guided decoding with thread-safe logits processing and token masking.
+
+    Manages asynchronous operations for efficient sampling with:
+    - Logits processors for constrained decoding
+    - Vocabulary masking
+    - Thread-safe state updates
     """
 
     def __init__(self):
         self.async_step = None
         self.token_bitmask = None
+        self.insert_processor = False
         self.logits_processor: Dict[int, Optional[Any]] = dict()
         self.executor = ThreadPoolExecutor()
         self.logits_lock = threading.Lock()
-        self.reasoning_parser = None
+        self.reasoning_end_id = None
 
     def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
-        self.reasoning_parser = reasoning_parser
+        if reasoning_parser:
+            self.reasoning_end_id = reasoning_parser.think_end_token_id
 
     def add_logits_processor(
-        self,
-        ids: int,
-        future: Optional[Any] = None,
-        prefill_tokens: List[int] = [],
+        self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = [], skip: bool = False
     ):
-        """add logits processor to SamplerProcessor"""
+        """Registers a logits processor for guided decoding.
+
+        Args:
+            ids: Unique sequence identifier
+            future: Logits processor instance or Future containing one
+            prefill_tokens: Initial tokens to pre-load
+            skip: Whether to skip processor insertion
+        """
+        if self.async_step is not None:
+            self.async_step.result()
+            self.async_step = None
+
         with self.logits_lock:
             if future is None:
                 if ids in self.logits_processor:
                     del self.logits_processor[ids]
                 return
 
+            if not skip:
+                self.insert_processor = True
+
             if isinstance(future, LogitsProcessorBase):
                 self.logits_processor[ids] = future
                 for token in prefill_tokens:
@@ -91,15 +110,25 @@ def add_logits_processor(
             else:
                 self.logits_processor[ids] = [future, prefill_tokens]
 
-    def update_vocab_mask(self, skip_idx_list: List[int] = []):
-        """update vocab mask. (cpu-heavy operation)"""
+    def get_available_processors(self):
+        """
+        get available logits processor
+        """
+        with self.logits_lock:
+            for processor in self.logits_processor.values():
+                if processor.is_terminated() or not isinstance(processor, LogitsProcessorBase):
+                    continue
+                return processor
+        return None
+
+    def update_logits_processor(self):
+        """update logits processor"""
         if len(self.logits_processor) == 0:
             return
 
         with self.logits_lock:
             for idx, processor in self.logits_processor.items():
                 if processor is None:
-                    del self.logits_processor[idx]
                     continue
 
                 if not isinstance(processor, LogitsProcessorBase):
@@ -108,45 +137,70 @@ def update_vocab_mask(self, skip_idx_list: List[int] = []):
                     for token in prefill_tokens:
                         self.logits_processor[idx].accept_token(token)
 
-        available_processors = None
-        for processor in self.logits_processor.values():
-            if processor.is_terminated():
-                continue
-            available_processors = processor
-        if available_processors is None:
-            return
+    def update_vocab_mask(self, skip_idx_list: List[int] = []):
+        """Updates vocabulary mask based on active constraints.
+
+        Note: This is a CPU-intensive operation that:
+        1. Processes pending logits processors
+        2. Allocates and fills token bitmask
+
+        Args:
+            skip_idx_list: Sequence IDs to exclude from masking
+        """
+        if len(self.logits_processor) == 0:
+            return
+
+        available_processors = self.get_available_processors()
+        if available_processors is None:
+            return
 
         # allocate token bitmask
-        self.token_bitmask = available_processors.allocate_token_bitmask()
+        token_bitmask = available_processors.allocate_token_bitmask()
+        self.update_logits_processor()
 
         with self.logits_lock:
+            # TODO: support filling the token bitmask in parallel
             # fill token bitmask
             for idx, processor in self.logits_processor.items():
                 if processor.is_terminated() or idx in skip_idx_list:
                     continue
-                processor.fill_token_bitmask(self.token_bitmask, idx)
+                processor.fill_token_bitmask(token_bitmask, idx)
+        self.token_bitmask = paddle.to_tensor(token_bitmask.numpy())
 
     def apply_token_mask(self, logits: paddle.Tensor, skip_idx_list: List[int] = []):
-        """apply token mask to logits"""
-        if len(self.logits_processor) == 0 or self.token_bitmask is None:
+        """Applies vocabulary mask to restrict token sampling.
+
+        Args:
+            logits: Input logits tensor
+            skip_idx_list: Sequence IDs to exclude from masking
+
+        Returns:
+            Masked logits tensor
+        """
+        if len(self.logits_processor) == 0:
             return logits
 
-        # self.async_step.result()
-        available_processors = None
-        with self.logits_lock:
-            for processor in self.logits_processor.values():
-                if processor.is_terminated():
-                    continue
-                available_processors = processor
-        if available_processors is None:
+        self.update_logits_processor()
+        if self.insert_processor:
+            # TODO: only update the processors at the newly inserted positions
+            self.update_vocab_mask(skip_idx_list)
+            with self.logits_lock:
+                self.insert_processor = False
+
+        if self.async_step is not None:
+            self.async_step.result()
+            self.async_step = None
+
+        available_processors = self.get_available_processors()
+        if available_processors is None or self.token_bitmask is None:
             return logits
 
         indices = []
         for idx, processor in self.logits_processor.items():
-            if processor is None or idx in skip_idx_list:
+            if processor is None or processor.is_terminated() or idx in skip_idx_list:
                 continue
-            if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
+            if self.reasoning_end_id is None or not processor.enable_reasoning or processor.reasoning_ended:
                 indices.append(idx)
 
         return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
@@ -160,41 +214,73 @@ def _accept_token(self, idx: int, token: int):
             return
 
         if (
-            self.reasoning_parser is not None
+            self.reasoning_end_id is not None
             and self.logits_processor[idx].enable_reasoning
             and not self.logits_processor[idx].reasoning_ended
         ):
-            reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
-            self.logits_processor[idx].reasoning_ended = reasoning_ended
+            # check reasoning end
+            self.logits_processor[idx].reasoning_ended = self.reasoning_end_id == token
             return
 
         self.logits_processor[idx].accept_token(token)
 
-    def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
-        """update output tokens"""
+    def update_output_tokens(
+        self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = [], skip_list_next: List[int] = []
+    ):
+        """Updates processors with newly generated tokens asynchronously.
+
+        Args:
+            next_tokens: Newly sampled tokens
+            skip_idx_list: Current step IDs to skip
+            skip_list_next: Next step IDs to skip
+        """
         if len(self.logits_processor) == 0:
             return
 
-        token_ids = next_tokens.numpy().tolist()
-        with self.logits_lock:
-            for idx in self.logits_processor.keys():
-                token = token_ids[idx][0]
-                if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
-                    continue
+        # create async operation for guided decoding
+        def async_update(next_tokens, skip_idx_list, skip_list_next):
+            with self.logits_lock:
+                for idx_tuple, token in np.ndenumerate(next_tokens.numpy()):
+                    idx = idx_tuple[0]
+                    if token < 0 or self.logits_processor[idx] is None or idx in skip_idx_list:
+                        continue
+                    self._accept_token(idx, token)
 
-                self._accept_token(idx, token)
+            self.update_vocab_mask(skip_list_next)
+
+        self.async_step = self.executor.submit(async_update, next_tokens, skip_idx_list, skip_list_next)
 
-    def pre_process(self, skip_idx_list: List[int] = []):
-        """pre process before running"""
-        # create async operation for guided decoding
-        # TODO: support async
-        self.update_vocab_mask(skip_idx_list)
-        # self.async_step = self.executor.submit(self.update_vocab_mask)
+
+class SamplerBase(nn.Layer):
+    def __init__(self):
+        """Base class for sampler"""
+        super().__init__()
+
+    @abstractmethod
+    def forward_cuda(self, *args, **kwargs) -> SamplerOutput:
+        pass
+
+    def apply_logits_processor(self, *args, **kwargs):
+        """apply logits processor to sampler"""
+        pass
+
+    def async_post_process(self, *args, **kwargs):
+        """async accept token from outside, update vocab mask for inside"""
+        pass
+
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        pass
 
 
-class Sampler(nn.Layer):
+class Sampler(SamplerBase):
     """
-    Sampler for normal generation.
+    Normal generation sampler with guided decoding support.
+
+    Features:
+    - Top-p (nucleus) sampling
+    - Repetition/frequency penalties
+    - Integration with guided decoding processors
     """
 
     def __init__(self, fd_config: FDConfig = None):
@@ -227,17 +313,24 @@ def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = Non
         """set reasoning parser"""
         self.processor.apply_reasoning_parser(reasoning_parser)
 
-    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
+    def apply_logits_processor(
+        self,
+        ids: int,
+        future: Optional[Any] = None,
+        prefill_tokens: List[int] = [],
+        skip: bool = False,
+    ):
         """apply logits processor to sampler"""
-        self.processor.add_logits_processor(ids, future, prefill_tokens)
+        self.processor.add_logits_processor(ids, future=future, prefill_tokens=prefill_tokens, skip=skip)
 
-    def pre_process(self, skip_idx_list: List[int] = []):
-        """pre process before running"""
-        self.processor.pre_process(skip_idx_list)
-
-    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
-        """post process after running"""
-        self.processor.update_output_tokens(next_tokens, skip_idx_list)
+    def async_post_process(
+        self,
+        next_tokens: paddle.Tensor,
+        skip_idx_list: List[int] = [],
+        skip_list_next: List[int] = [],
+    ):
+        """async accept token from outside, update vocab mask for inside"""
+        self.processor.update_output_tokens(next_tokens, skip_idx_list, skip_list_next)
 
     def compute_logprobs(
         self,
@@ -327,6 +420,7 @@ def forward_cuda(
         skip_idx_list: List[int] = [],
     ) -> SamplerOutput:
         """ """
+        # guided decoding: apply token mask to logits
         logits = self.processor.apply_token_mask(logits, skip_idx_list)
 
         num_logprobs = sampling_metadata.max_num_logprobs
@@ -378,7 +472,7 @@ def forward_cuda(
         return sampler_output
 
 
-class SpeculativeSampler(nn.Layer):
+class SpeculativeSampler(SamplerBase):
     """
     Sampler for speculative generation.
""" @@ -394,22 +488,6 @@ def __init__(self, fd_config: FDConfig): self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode - def pre_process(self, skip_idx_list: List[int] = []): - """pre process before running""" - pass - - def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): - """set reasoning parser""" - pass - - def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): - """post process after running""" - pass - - def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []): - """apply logits processor to sampler""" - pass - def forward_cuda( self, logits: paddle.Tensor, @@ -477,7 +555,7 @@ def forward_cuda( return None -class MTPSampler(nn.Layer): +class MTPSampler(SamplerBase): """ """ def __init__(self, fd_config: FDConfig): @@ -488,27 +566,6 @@ def __init__(self, fd_config: FDConfig): else: raise NotImplementedError - def pre_process(self, skip_idx_list: List[int] = []): - """pre process before running""" - pass - - def apply_logits_processor( - self, - ids: int, - future: Optional[Any] = None, - prefill_tokens: List[int] = [], - ): - """apply logits processor to sampler""" - pass - - def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): - """set reasoning parser""" - pass - - def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): - """post process after running""" - pass - def forward_cuda( self, logits: paddle.Tensor, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8bb0239d53..0a469fadee 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -264,6 +264,8 @@ def _init_logits_processor(self, request): self.guided_backend.get_logits_processor( schemata_key=schemata_key, enable_thinking=enable_thinking, + override_stop_tokens=request.eos_token_ids, + request_id=request.request_id, ), schemata_key, )