[V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled #17731

Merged
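This PR removes the per-sample deepcopy of SamplingParams during parallel sampling: vllm/sequence.py now calls params.clone() instead of copy.deepcopy(original_params), and each guided-decoding logits processor (llguidance, outlines, xgrammar) gains a clone() that shares its expensive compiled grammar state and only duplicates the small per-sequence state. A minimal sketch of that clone pattern follows; the class and attribute names are illustrative only, not vLLM's actual API:

# Illustrative sketch only: GuidedProcessor, compiled_grammar and
# consumed_tokens are hypothetical names, not part of this PR.
import copy


class GuidedProcessor:

    def __init__(self, compiled_grammar: object) -> None:
        self.compiled_grammar = compiled_grammar  # expensive, immutable once built
        self.consumed_tokens: list[int] = []  # cheap per-sequence state

    def clone(self) -> "GuidedProcessor":
        # Shallow copy shares compiled_grammar by reference; only the mutable
        # per-sequence state gets a fresh copy.
        cloned = copy.copy(self)
        cloned.consumed_tokens = list(self.consumed_tokens)
        return cloned


# One cheap clone per parallel sample instead of one full deepcopy per sample.
processor = GuidedProcessor(compiled_grammar=object())
clones = [processor.clone() for _ in range(4)]
assert all(c.compiled_grammar is processor.compiled_grammar for c in clones)
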
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import copy
 import os
 from typing import Any

@@ -34,9 +35,24 @@ def __init__(
         self.grammar = grammar
         self.tokenizer = tokenizer
         self.tokenizer_name = tokenizer.name_or_path
+        self.ll_tokenizer = None
+        self.ll_matcher = None
+        self.bitmask = None
         self.new_sampling = False
         self.initialized = False

+    def clone(self) -> "GuidanceLogitsProcessor":
+        cloned = copy.copy(self)
+        if self.initialized:
+            cloned.ll_matcher = llguidance.LLMatcher(
+                self.ll_tokenizer,  # type: ignore[assignment]
+                self.grammar,
+                log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+            )
+            cloned.bitmask = llguidance.torch.allocate_token_bitmask(
+                1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]
+        return cloned
+
     def _initialize(self):
         if self.initialized:
             return
@@ -56,7 +72,7 @@ def _initialize(self):

         # create reusable bitmask
         self.bitmask = llguidance.torch.allocate_token_bitmask(
-            1, self.ll_tokenizer.vocab_size)
+            1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]

         self.initialized = True

@@ -70,15 +86,17 @@ def __call__(
         self._initialize()

         if self.new_sampling and len(input_ids) > 0:
-            self.ll_matcher.consume_token(input_ids[-1])
-            err = self.ll_matcher.get_error()
+            self.ll_matcher.consume_token(  # type: ignore[attr-defined]
+                input_ids[-1])
+            err = self.ll_matcher.get_error()  # type: ignore[attr-defined]
             if err:
                 logger.warning("Error in LLMatcher: %s", err)

         llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
                                                  0)
         llguidance.torch.apply_token_bitmask_inplace(
-            scores, self.bitmask.to(scores.device))
+            scores,
+            self.bitmask.to(scores.device))  # type: ignore[attr-defined]

         self.new_sampling = True

@@ -56,6 +56,12 @@ def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
         self._fsm_state: defaultdict[int, Union[int,
                                                 CFGState]] = defaultdict(int)

+    def clone(self) -> "BaseLogitsProcessor":
+        cloned = copy.copy(self)
+        cloned._guide = self._guide.copy()
+        cloned._fsm_state = copy.deepcopy(self._fsm_state)
+        return cloned
+
Collaborator comment on lines +59 to +64:

I would like to get #15975 in first before assigning these private attrs.

     def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
@@ -218,6 +224,12 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
                          reasoner)
         self._guide = self._guide.copy()

+    def clone(self) -> "CFGLogitsProcessor":
+        cloned = copy.copy(self)
+        cloned._fsm_state = copy.deepcopy(self._fsm_state)
+        cloned._guide = self._guide.copy()
+        return cloned
+

 @lru_cache(maxsize=32)
 def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
8 changes: 5 additions & 3 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -302,8 +302,9 @@ class XGrammarLogitsProcessor:
     prefilled: bool = field(default=False)

     def __post_init__(self):
-        self.tokenizer_info = self.config.tokenizer_info(
-            self.config.tokenizer_data)
+        if self.tokenizer_info is None:
+            self.tokenizer_info = self.config.tokenizer_info(
+                self.config.tokenizer_data)

     def __getstate__(self) -> dict[str, Any]:
         return {'config': self.config, 'reasoner': self.reasoner}
@@ -400,7 +401,8 @@ def __call__(self, input_ids: list[int],
     def clone(self) -> XGrammarLogitsProcessor:
         """Create a new instance with shared compiled grammar
         but separate state"""
-        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)
+        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
+                                                None, self.tokenizer_info)

         # Share the compiled grammar context (immutable after compilation)
         new_processor.ctx = self.ctx
2 changes: 1 addition & 1 deletion vllm/sequence.py
@@ -1494,7 +1494,7 @@ def add_request(request_id: str, engine, params, **kwargs):
         for i in range(original_params.n):
             request_id_i = f"{request_id}_parallel_sample_{i}"
             group.seq_id_to_index[request_id_i] = i
-            params = copy.deepcopy(original_params)
+            params = params.clone()
             params.n = 1
             if params.seed is not None:
                 params.seed += i
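
The params.clone() call above relies on SamplingParams copying its logits processors cheaply; that method is not shown in this diff. A hedged sketch of how such a clone can route each processor through its own clone() while deep-copying the rest (clone_params and the logits_processors attribute are assumptions made for illustration):

# Sketch under assumptions: params is any object carrying an optional
# logits_processors list; clone_params is a hypothetical helper.
import copy


def clone_params(params):
    # Pre-seed deepcopy's memo so any processor that defines clone() is
    # replaced by that clone instead of being deep-copied field by field.
    memo = {}
    for lp in getattr(params, "logits_processors", None) or []:
        if hasattr(lp, "clone"):
            memo[id(lp)] = lp.clone()
    return copy.deepcopy(params, memo)

With a clone path like this, the loop above pays for one shallow processor clone per parallel sample rather than re-copying the compiled grammar n times.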