-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[Feature] Support llguidance for constrained decoding #3298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
23bd5c3
Add llguidance support
JC1DA ed8f927
Compile regex before creating GuidanceGrammar
JC1DA 07ed65a
Add support for ebnf grammar
JC1DA 6fe4019
remove llguidance_utils
JC1DA 51a6ac0
Update llguidance to 0.6.15
JC1DA 35ba9d8
Fix incorrect bitmask shape when batch size changes
JC1DA 473e7e2
Add llguidance testcases
JC1DA c84af1b
Use LarkCompiler to compile ebnf to grammar
JC1DA 4495cc9
format code
JC1DA b165a8a
Update structured outputs docs
JC1DA 3bf08ad
Update structured outputs docs
JC1DA 8f97d2b
Fix typos
JC1DA 4029cc2
add TestJumpForwardLLGuidance testcase
JC1DA 1535cb3
Add TODO for fast-forward tokens support
JC1DA c7e8c31
Merge branch 'main' into integrate_guidance
zhaochenyang20 d4bdd96
Merge branch 'main' into integrate_guidance
zhaochenyang20 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright 2023-2024 SGLang Team | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Constrained decoding with llguidance backend.""" | ||
|
||
import json | ||
import os | ||
from typing import List, Optional, Tuple | ||
|
||
import llguidance | ||
import llguidance.hf | ||
import llguidance.torch | ||
import torch | ||
from llguidance.gbnf_to_lark import any_to_lark | ||
|
||
from sglang.srt.constrained.base_grammar_backend import ( | ||
BaseGrammarBackend, | ||
BaseGrammarObject, | ||
) | ||
|
||
|
||
class GuidanceGrammar(BaseGrammarObject): | ||
def __init__( | ||
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str | ||
): | ||
self.llguidance_tokenizer = llguidance_tokenizer | ||
self.serialized_grammar = serialized_grammar | ||
|
||
# TODO: add support for fast-forward tokens in the future | ||
self.ll_interpreter = llguidance.LLInterpreter( | ||
self.llguidance_tokenizer, | ||
self.serialized_grammar, | ||
enable_backtrack=False, | ||
enable_ff_tokens=False, | ||
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), | ||
) | ||
self.pending_ff_tokens: list[int] = [] | ||
self.finished = False | ||
self.bitmask = None | ||
|
||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: | ||
if len(self.pending_ff_tokens) > 0: | ||
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) | ||
ff_tokens = self.pending_ff_tokens | ||
self.pending_ff_tokens = [] | ||
return (ff_tokens, s) | ||
|
||
return None | ||
|
||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: | ||
return "", -1 | ||
|
||
def jump_and_retokenize( | ||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int | ||
): | ||
pass | ||
|
||
def accept_token(self, token: int): | ||
backtrack, ff_tokens = self.ll_interpreter.commit_token(token) | ||
if len(ff_tokens) > 0 and backtrack == 0: | ||
# first token is last generated token | ||
ff_tokens = ff_tokens[1:] | ||
self.pending_ff_tokens.extend(ff_tokens) | ||
|
||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: | ||
if len(self.pending_ff_tokens) > 0: | ||
# if we have pending fast-forward tokens, | ||
# just return them immediately | ||
ff_token = self.pending_ff_tokens.pop(0) | ||
vocab_mask[idx, :] = 0 | ||
vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32) | ||
return | ||
|
||
if self.ll_interpreter.has_pending_stop(): | ||
self.finished = True | ||
|
||
llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx) | ||
|
||
def allocate_vocab_mask( | ||
self, vocab_size: int, batch_size: int, device | ||
) -> torch.Tensor: | ||
if self.bitmask is None or self.bitmask.shape[0] < batch_size: | ||
# only create bitmask when batch gets larger | ||
self.bitmask = llguidance.torch.allocate_token_bitmask( | ||
batch_size, self.llguidance_tokenizer.vocab_size | ||
) | ||
bitmask = self.bitmask | ||
else: | ||
bitmask = self.bitmask[:batch_size] | ||
|
||
return bitmask | ||
|
||
@staticmethod | ||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: | ||
return vocab_mask.to(device, non_blocking=True) | ||
|
||
@staticmethod | ||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: | ||
llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask) | ||
|
||
def copy(self): | ||
return GuidanceGrammar( | ||
llguidance_tokenizer=self.llguidance_tokenizer, | ||
serialized_grammar=self.serialized_grammar, | ||
) | ||
|
||
|
||
class GuidanceBackend(BaseGrammarBackend): | ||
def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None): | ||
super().__init__() | ||
|
||
self.tokenizer = tokenizer | ||
self.whitespace_flexible = ( | ||
True if whitespace_pattern == "whitespace_flexible" else False | ||
) | ||
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) | ||
|
||
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar: | ||
mode, value = key | ||
if mode == "json": | ||
json_schema = value | ||
compiler = llguidance.JsonCompiler( | ||
whitespace_flexible=self.whitespace_flexible | ||
) | ||
serialized_grammar = compiler.compile(json_schema) | ||
elif mode == "regex": | ||
compiler = llguidance.RegexCompiler() | ||
serialized_grammar = compiler.compile(regex=value) | ||
elif mode == "ebnf": | ||
compiler = llguidance.LarkCompiler() | ||
serialized_grammar = compiler.compile(any_to_lark(value)) | ||
|
||
return GuidanceGrammar( | ||
llguidance_tokenizer=self.llguidance_tokenizer, | ||
serialized_grammar=serialized_grammar, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shuaills How can we set xgrammar as the default one?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhaochenyang20 ServerArgs.grammar_backend is set to outlines by default (from main branch, we didn't set this value)
https://github.com/sgl-project/sglang/blob/45c465b923cada5bea8cb5b74eedda042a49dc9d/python/sglang/srt/server_args.py#L121
Do you want me to update the flag to xgrammar by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will have a separate PR for that, thanks