Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def get_reasoning_parser(self):
def _init_logits_processor(
self,
schemata_key: tuple[str, str],
enable_thinking: bool = False,
**kwargs,
) -> LogitsProcessorBase:
"""
init logits processor by type and schemata.
Expand All @@ -248,21 +248,21 @@ def _init_logits_processor(
"""
key_type, schemata = schemata_key
if key_type == "json":
return self._json_processor(schemata, enable_thinking)
return self._json_processor(schemata, **kwargs)
elif key_type == "regex":
return self._regex_processor(schemata, enable_thinking)
return self._regex_processor(schemata, **kwargs)
elif key_type == "grammar":
return self._grammar_processor(schemata, enable_thinking)
return self._grammar_processor(schemata, **kwargs)
elif key_type == "structural_tag":
return self._structural_tag_processor(schemata, enable_thinking)
return self._structural_tag_processor(schemata, **kwargs)
else:
llm_logger.error(f"Unsupported processor type {key_type}.")
return None

def get_logits_processor(
self,
schemata_key: tuple[str, str],
enable_thinking: bool = False,
**kwargs,
) -> tuple[LogitsProcessorBase, bool]:
"""
get logits processor by key from cache or create new one.
Expand All @@ -278,9 +278,9 @@ def get_logits_processor(
value = self.cache.get(schemata_key, None)
if value:
value_copy = value.copy()
value_copy.enable_reasoning = enable_thinking
value_copy.enable_reasoning = kwargs.get("enable_thinking", False)
return value_copy, True
value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
value = self.executor.submit(self._init_logits_processor, schemata_key, **kwargs)
return value, False

def _get_tokenizer_hf(self):
Expand Down
61 changes: 31 additions & 30 deletions fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
StructuralTagItem,
TokenizerInfo,
allocate_token_bitmask,
apply_token_bitmask_inplace,
)

from .kernels.xgrammar_apply_token_bitmask import apply_token_bitmask_inplace_triton
except Exception as e:
raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")

Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
vocab_size: Optional[int] = None,
batch_size: Optional[int] = None,
enable_thinking: bool = False,
request_id: Optional[str] = None,
):
super().__init__(enable_reasoning=enable_thinking)
self.max_rollback_tokens = 200
Expand All @@ -79,6 +81,7 @@ def __init__(
self.compiled_grammar = compiled_grammar
self.terminate_without_stop_token = terminate_without_stop_token
self.override_stop_tokens = override_stop_tokens
self.request_id = request_id

self.matcher = GrammarMatcher(
compiled_grammar=compiled_grammar,
Expand Down Expand Up @@ -112,36 +115,29 @@ def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
def apply_token_mask(
self,
logits: paddle.Tensor,
token_bitmask: torch.Tensor,
token_bitmask: paddle.Tensor,
indices: Optional[List[int]] = None,
) -> paddle.Tensor:
"""
Apply the token mask to the logits, modifying probabilities of invalid tokens.

Args:
logits (paddle.Tensor): The logits tensor to modify
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
token_bitmask (paddle.Tensor): The token bitmask indicating allowed tokens
indices (Optional[List[int]]): Optional list of batch indices to apply mask to

Returns:
paddle.Tensor: The modified logits tensor
"""
origin_place = logits.place
origin_dtype = logits.dtype
logits = torch.from_numpy(logits.numpy())

logits = logits.float() # cpu
apply_token_bitmask_inplace(
logits=logits,
bitmask=token_bitmask.to(logits.device, non_blocking=True),
indices=indices,
)
if token_bitmask.place != logits.place:
token_bitmask = token_bitmask.to(device=logits.place)

return paddle.to_tensor(
logits.numpy(),
dtype=origin_dtype,
place=origin_place,
)
if logits.place.is_gpu_place():
apply_token_bitmask_inplace_triton(logits, token_bitmask, self.vocab_size, indices)
else:
llm_logger.error(f"Unsupported device {logits.place}, skip guided decoding.")

return logits

def reset(self) -> None:
"""
Expand All @@ -156,13 +152,16 @@ def accept_token(self, token: int) -> None:
"""
Validate and accept a generated token against the grammar constraints.

# When the output token reaches the maximum length,
# it will be forced to get an eos_token, the output is not restricted by guided decoding

Args:
token (int): The token ID to validate

Raises:
AssertionError: If token is not allowed by the grammar
"""
assert self.matcher.accept_token(token), f"Failed to accept token {token}"
if not self.matcher.accept_token(token):
llm_logger.error(f"request: {self.request_id} failed to accept token [{token}]")
return False
return True

def is_terminated(self) -> bool:
"""
Expand Down Expand Up @@ -226,6 +225,7 @@ def _create_processor(
terminate_without_stop_token: bool = False,
override_stop_tokens: Optional[List[int]] = None,
enable_thinking: bool = False,
request_id: Optional[str] = None,
) -> XGrammarProcessor:
"""
Create a logits processor instance for the given compiled grammar.
Expand All @@ -246,9 +246,10 @@ def _create_processor(
vocab_size=self.vocab_size,
batch_size=self.batch_size,
enable_thinking=enable_thinking,
request_id=request_id,
)

def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
def _json_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
"""
Compile JSON schema into a grammar processor.

Expand All @@ -264,9 +265,9 @@ def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optio
except Exception as e:
llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
return self._create_processor(compiled_grammar, **kwargs)

def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
def _regex_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
"""
Compile regex pattern into a grammar processor.

Expand All @@ -282,9 +283,9 @@ def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Opti
except Exception as e:
llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
return self._create_processor(compiled_grammar, **kwargs)

def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
def _grammar_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
"""
Compile grammar (EBNF) into a grammar processor.

Expand All @@ -300,9 +301,9 @@ def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Op
except Exception as e:
llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
return self._create_processor(compiled_grammar, **kwargs)

def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
def _structural_tag_processor(self, schemata: str, **kwargs) -> Optional[XGrammarProcessor]:
"""
Compile structural tags into a grammar processor.

Expand All @@ -327,7 +328,7 @@ def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False
except Exception as e:
llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
return None
return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
return self._create_processor(compiled_grammar, **kwargs)


class XGrammarChecker(BaseChecker):
Expand Down
Loading
Loading