5 changes: 4 additions & 1 deletion docs/backend/structured_outputs.ipynb
@@ -17,10 +17,13 @@
"\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
"- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n",
"\n",
"We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n",
"To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default.\n",
"To use Xgrammar, simply add `--grammar-backend xgrammar` when launching the server.\n",
"To use llguidance, add `--grammar-backend llguidance` when launching the server.\n",
"If no backend is specified, Outlines will be used as the default.\n",
"\n",
"For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n"
]
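For concreteness, here is a minimal client-side sketch of requesting schema-constrained output from a server launched with one of the backends above. The /generate endpoint and the json_schema sampling parameter follow the test suite in this PR; the model, port, and prompt are illustrative assumptions.

import json
import requests

# Minimal sketch: assumes a server launched with, e.g.,
#   python3 -m sglang.launch_server --model-path <model> --grammar-backend llguidance
# and listening on localhost:30000 (port and model are assumptions).
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
    "additionalProperties": False,
}

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Give information about the capital of France in JSON format.",
        "sampling_params": {
            "max_new_tokens": 64,
            "json_schema": json.dumps(schema),
        },
    },
)
print(response.json()["text"])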
1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -38,6 +38,7 @@ runtime_common = [
"xgrammar==0.1.10",
"ninja",
"transformers==4.48.3",
"llguidance>=0.6.15"
]
srt = [
"sglang[runtime_common]",
7 changes: 7 additions & 0 deletions python/sglang/srt/constrained/base_grammar_backend.py
@@ -86,6 +86,13 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
elif server_args.grammar_backend == "llguidance":
from sglang.srt.constrained.llguidance_backend import GuidanceBackend

grammar_backend = GuidanceBackend(
tokenizer=tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
else:
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

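For context on the new branch above, a minimal sketch of constructing the backend directly, outside the server. GuidanceBackend, its arguments, and init_value_impl come from this PR; the tokenizer name is an illustrative assumption.

from transformers import AutoTokenizer

from sglang.srt.constrained.llguidance_backend import GuidanceBackend

# Minimal sketch, assuming the llguidance package is installed;
# the model name is illustrative, not part of this PR.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
backend = GuidanceBackend(tokenizer=tokenizer, whitespace_pattern=None)

# init_value_impl compiles a grammar for a (mode, value) key;
# this backend supports "json", "regex", and "ebnf" modes.
grammar = backend.init_value_impl(("regex", r"[A-Z][a-z]+"))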
146 changes: 146 additions & 0 deletions python/sglang/srt/constrained/llguidance_backend.py
@@ -0,0 +1,146 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with llguidance backend."""

import json
import os
from typing import List, Optional, Tuple

import llguidance
import llguidance.hf
import llguidance.torch
import torch
from llguidance.gbnf_to_lark import any_to_lark

from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend,
BaseGrammarObject,
)


class GuidanceGrammar(BaseGrammarObject):
def __init__(
self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
):
self.llguidance_tokenizer = llguidance_tokenizer
self.serialized_grammar = serialized_grammar

# TODO: add support for fast-forward tokens in the future
self.ll_interpreter = llguidance.LLInterpreter(
self.llguidance_tokenizer,
self.serialized_grammar,
enable_backtrack=False,
enable_ff_tokens=False,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.pending_ff_tokens: list[int] = []
self.finished = False
self.bitmask = None

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
if len(self.pending_ff_tokens) > 0:
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
ff_tokens = self.pending_ff_tokens
self.pending_ff_tokens = []
return (ff_tokens, s)

return None

def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1

def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
pass

def accept_token(self, token: int):
backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
if len(ff_tokens) > 0 and backtrack == 0:
# first token is last generated token
ff_tokens = ff_tokens[1:]
self.pending_ff_tokens.extend(ff_tokens)

def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
if len(self.pending_ff_tokens) > 0:
            # if there are pending fast-forward tokens,
            # constrain the mask to the next one so it is emitted verbatim
ff_token = self.pending_ff_tokens.pop(0)
vocab_mask[idx, :] = 0
vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
return

if self.ll_interpreter.has_pending_stop():
self.finished = True

llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)

def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
if self.bitmask is None or self.bitmask.shape[0] < batch_size:
            # (re)allocate only when there is no cached bitmask
            # or the batch outgrew the cached one
self.bitmask = llguidance.torch.allocate_token_bitmask(
batch_size, self.llguidance_tokenizer.vocab_size
)
bitmask = self.bitmask
else:
bitmask = self.bitmask[:batch_size]

return bitmask

@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask.to(device, non_blocking=True)

@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)

def copy(self):
return GuidanceGrammar(
llguidance_tokenizer=self.llguidance_tokenizer,
serialized_grammar=self.serialized_grammar,
)


class GuidanceBackend(BaseGrammarBackend):
def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
super().__init__()

self.tokenizer = tokenizer
        self.whitespace_flexible = whitespace_pattern == "whitespace_flexible"
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)

def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
mode, value = key
if mode == "json":
json_schema = value
compiler = llguidance.JsonCompiler(
whitespace_flexible=self.whitespace_flexible
)
serialized_grammar = compiler.compile(json_schema)
elif mode == "regex":
compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=value)
elif mode == "ebnf":
compiler = llguidance.LarkCompiler()
            serialized_grammar = compiler.compile(any_to_lark(value))
        else:
            raise ValueError(f"Invalid grammar mode: {mode}")

return GuidanceGrammar(
llguidance_tokenizer=self.llguidance_tokenizer,
serialized_grammar=serialized_grammar,
)
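A note on the mask layout that fill_vocab_mask above relies on: llguidance's torch helpers pack the vocabulary into 32-bit words, with token id t occupying bit t % 32 of word t // 32. Below is a self-contained sketch of that convention; the helper functions are illustrative, not part of the PR.

import torch

# Illustrative helpers mirroring the packed-bitmask convention used by
# fill_vocab_mask above; not part of this PR.
def allow_only(vocab_mask: torch.Tensor, row: int, token_id: int) -> None:
    # Zero the row, then set the single bit for token_id.
    # (Bit 31 would need signed-int32 handling; omitted in this sketch.)
    vocab_mask[row, :] = 0
    vocab_mask[row, token_id // 32] = 1 << (token_id % 32)

def is_allowed(vocab_mask: torch.Tensor, row: int, token_id: int) -> bool:
    word = int(vocab_mask[row, token_id // 32])
    return bool(word & (1 << (token_id % 32)))

vocab_size = 100
n_words = (vocab_size + 31) // 32  # one int32 word per 32 token ids
mask = torch.zeros((1, n_words), dtype=torch.int32)

allow_only(mask, row=0, token_id=42)
assert is_allowed(mask, 0, 42) and not is_allowed(mask, 0, 43)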
2 changes: 1 addition & 1 deletion python/sglang/srt/server_args.py
@@ -698,7 +698,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--grammar-backend",
type=str,
choices=["xgrammar", "outlines"],
choices=["xgrammar", "outlines", "llguidance"],
default=ServerArgs.grammar_backend,
Review thread on this line:

Collaborator: @shuaills How can we set xgrammar as the default one?

Contributor Author (@JC1DA, Feb 24, 2025): @zhaochenyang20 ServerArgs.grammar_backend is set to outlines by default (from main branch, we didn't set this value).
https://github.com/sgl-project/sglang/blob/45c465b923cada5bea8cb5b74eedda042a49dc9d/python/sglang/srt/server_args.py#L121
Do you want me to update the flag to xgrammar by default?

Collaborator: We will have a separate PR for that, thanks.

help="Choose the backend for grammar-guided decoding.",
)
15 changes: 12 additions & 3 deletions test/srt/test_ebnf_constrained.py
@@ -1,6 +1,8 @@
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
"""

import json
@@ -17,7 +19,7 @@
)


def setup_class(cls, disable_overlap: bool):
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.ebnf_grammar = 'root ::= "test"' # Default grammar
@@ -26,7 +28,7 @@ def setup_class(cls, disable_overlap: bool):
"--max-running-requests",
"10",
"--grammar-backend",
"xgrammar",
backend,
]

if disable_overlap:
@@ -43,7 +45,7 @@ def setup_class(cls, disable_overlap: bool):
class TestEBNFConstrained(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=False)
setup_class(cls, "xgrammar", disable_overlap=False)
cls.check_jump_forward = False

@classmethod
@@ -236,5 +238,12 @@ def test_ebnf_generate_custom_log_format(self):
)


class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
cls.check_jump_forward = False


if __name__ == "__main__":
unittest.main()
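As background for these EBNF tests, a short sketch of how the llguidance backend ingests EBNF, mirroring the "ebnf" branch of GuidanceBackend.init_value_impl above; the grammar string is illustrative.

import llguidance
from llguidance.gbnf_to_lark import any_to_lark

# Illustrative GBNF-style grammar, same shape as the tests' default grammar.
ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"'

# Convert GBNF/EBNF to Lark syntax, then compile to a serialized grammar.
compiler = llguidance.LarkCompiler()
serialized_grammar = compiler.compile(any_to_lark(ebnf_grammar))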
9 changes: 9 additions & 0 deletions test/srt/test_json_constrained.py
@@ -1,6 +1,7 @@
"""
python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
"""

import json
@@ -30,6 +31,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
"population": {"type": "integer"},
},
"required": ["name", "population"],
"additionalProperties": False,
}
)

@@ -146,5 +148,12 @@ def setUpClass(cls):
cls.check_jump_forward = False


class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="llguidance", disable_overlap=False)
cls.check_jump_forward = False


if __name__ == "__main__":
unittest.main()
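The "additionalProperties": False added above tightens the schema so constrained decoding cannot emit keys beyond name and population. A quick illustration using the third-party jsonschema package, which is an assumption for illustration only; the tests above do not use it.

import jsonschema

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
    "additionalProperties": False,
}

jsonschema.validate({"name": "Paris", "population": 2161000}, schema)  # passes

try:
    jsonschema.validate(
        {"name": "Paris", "population": 2161000, "country": "France"}, schema
    )
except jsonschema.ValidationError:
    print("rejected: extra key 'country' is not allowed")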
25 changes: 21 additions & 4 deletions test/srt/test_regex_constrained.py
@@ -1,6 +1,10 @@
"""
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_email
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_greeting
"""

import json
@@ -17,15 +21,15 @@
)


def setup_class(cls, disable_overlap: bool):
def setup_class(cls, backend: str, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST

other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
"xgrammar",
backend,
]

if disable_overlap:
@@ -42,7 +46,7 @@ def setup_class(cls, disable_overlap: bool):
class TestRegexConstrained(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=False)
setup_class(cls, "xgrammar", disable_overlap=False)
cls.check_jump_forward = False

@classmethod
@@ -178,9 +182,22 @@ def test_regex_generate_custom_log_format(self):
class TestJumpForward(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=True)
setup_class(cls, "xgrammar", disable_overlap=True)
cls.check_jump_forward = True


class TestJumpForwardLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)
cls.check_jump_forward = True


class TestRegexConstrainedLLGuidance(TestRegexConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=True)


if __name__ == "__main__":
unittest.main()
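For completeness, the regex path exercised by these tests compiles patterns through llguidance's RegexCompiler; below is a minimal sketch mirroring the "regex" branch of GuidanceBackend.init_value_impl above. The pattern is illustrative.

import llguidance

# Illustrative email-like pattern, similar in spirit to the tests above.
pattern = r"[a-z]+@[a-z]+\.(com|org|net)"

compiler = llguidance.RegexCompiler()
serialized_grammar = compiler.compile(regex=pattern)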