Skip to content

Commit c8c892a

Browse files
Add Xgrammar Support (#701)
Change: - change launch argument from `--simple_constraint_mode` to `--output_constraint_mode`; now the user can choose the constraint-decode backend from ['outlines', 'xgrammar'] - add `XgrammarBackend` used for xgrammar constraint decode; maybe we should merge it with `SimpleConstraintBackend` later? - now we adopt the same request body as vLLM with xgrammar (https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html) - user can add `guided_grammar` to pass an EBNF grammar and run constraint decoding - user can add `guided_json` to pass a standard JSON schema and run constraint decoding --------- Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent c483b1e commit c8c892a

File tree

13 files changed

+517
-16
lines changed

13 files changed

+517
-16
lines changed

lightllm/server/api_cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
149149
parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
150150
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
151151
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
152-
parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")
152+
153+
parser.add_argument(
154+
"--output_constraint_mode",
155+
type=str,
156+
choices=["outlines", "xgrammar", "none"],
157+
default="none",
158+
help="set the output constraint backend, none means no output constraint",
159+
)
153160
parser.add_argument(
154161
"--first_token_constraint_mode",
155162
action="store_true",

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@ def __init__(
4747
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
4848
input_penalty: bool = DEFAULT_INPUT_PENALTY,
4949
regular_constraint: Optional[str] = None, # Regular expressions constrain the output.
50+
guided_grammar: Optional[str] = None, # EBNF constrain the output.
51+
guided_json: Optional[Union[str, dict]] = None, # JSON schema constrain the output.
5052
# If provided, the engine will construct a logits,
5153
# processor which only retains scores for the given token ids. Defaults to None.
52-
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
54+
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
5355
allowed_token_ids: Optional[List[int]] = None,
5456
# p d mode used params
5557
group_request_id: Optional[int] = None,
@@ -81,6 +83,8 @@ def __init__(
8183
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
8284
self.print_eos_token = print_eos_token
8385
self.regular_constraint = regular_constraint
86+
self.guided_grammar = guided_grammar
87+
self.guided_json = guided_json
8488
self.allowed_token_ids = allowed_token_ids
8589
self.group_request_id = group_request_id
8690
self.move_kv_to_decode_node = move_kv_to_decode_node
@@ -257,6 +261,8 @@ def to_dict(self):
257261
ret["best_of"] = self.best_of
258262
ret["input_penalty"] = self.input_penalty
259263
ret["regular_constraint"] = self.regular_constraint
264+
ret["guided_grammar"] = self.guided_grammar
265+
ret["guided_json"] = self.guided_json
260266
ret["allowed_token_ids"] = self.allowed_token_ids
261267
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
262268
return ret

lightllm/server/core/objs/sampling_params.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
1313
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
1414
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
15+
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
16+
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
1517

1618

1719
class StopSequence(ctypes.Structure):
@@ -76,7 +78,7 @@ def to_list(self):
7678
class RegularConstraint(ctypes.Structure):
7779
_pack_ = 4
7880
_fields_ = [
79-
("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH),
81+
("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH),
8082
("length", ctypes.c_int),
8183
]
8284

@@ -98,6 +100,66 @@ def to_str(self):
98100
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
99101

100102

103+
class GuidedGrammar(ctypes.Structure):
104+
_pack_ = 4
105+
_fields_ = [
106+
("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH),
107+
("length", ctypes.c_int),
108+
]
109+
110+
def initialize(self, constraint: str, tokenizer):
111+
constraint_bytes = constraint.encode("utf-8")
112+
assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long."
113+
114+
ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
115+
self.length = len(constraint_bytes)
116+
try:
117+
if self.length > 0:
118+
import xgrammar as xgr
119+
120+
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
121+
xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
122+
xgrammar_compiler.compile_grammar(constraint)
123+
except Exception as e:
124+
raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")
125+
return
126+
127+
def to_str(self):
128+
if self.length == 0:
129+
return ""
130+
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
131+
132+
133+
class GuidedJsonSchema(ctypes.Structure):
134+
_pack_ = 4
135+
_fields_ = [
136+
("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH),
137+
("length", ctypes.c_int),
138+
]
139+
140+
def initialize(self, constraint: str, tokenizer):
141+
constraint_bytes = constraint.encode("utf-8")
142+
assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long."
143+
144+
ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
145+
self.length = len(constraint_bytes)
146+
try:
147+
if self.length > 0:
148+
import xgrammar as xgr
149+
150+
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
151+
xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
152+
xgrammar_compiler.compile_json_schema(constraint)
153+
except Exception as e:
154+
raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")
155+
return
156+
157+
def to_str(self):
158+
if self.length == 0:
159+
return ""
160+
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
161+
162+
101163
class AllowedTokenIds(ctypes.Structure):
102164
_pack_ = 4
103165
_fields_ = [
@@ -191,9 +253,11 @@ class SamplingParams(ctypes.Structure):
191253
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
192254
("input_penalty", ctypes.c_bool),
193255
("regular_constraint", RegularConstraint),
256+
("guided_grammar", GuidedGrammar),
257+
("guided_json", GuidedJsonSchema),
194258
# If provided, the engine will construct a logits,
195259
# processor which only retains scores for the given token ids. Defaults to None.
196-
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
260+
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
197261
("allowed_token_ids", AllowedTokenIds),
198262
("stop_sequences", StopSequenceGroups),
199263
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
@@ -251,6 +315,16 @@ def init(self, tokenizer, **kwargs):
251315
self.regular_constraint = RegularConstraint()
252316
self.regular_constraint.initialize(regular_constraint)
253317

318+
# Initialize guided_grammar
319+
guided_grammar = kwargs.get("guided_grammar", "")
320+
self.guided_grammar = GuidedGrammar()
321+
self.guided_grammar.initialize(guided_grammar, tokenizer)
322+
323+
# Initialize guided_json
324+
guided_json = kwargs.get("guided_json", "")
325+
self.guided_json = GuidedJsonSchema()
326+
self.guided_json.initialize(guided_json, tokenizer)
327+
254328
# Initialize stop_sequence_groups
255329
stop_sequences = kwargs.get("stop_sequences", [])
256330
self.stop_sequences = StopSequenceGroups()
@@ -316,13 +390,26 @@ def verify(self):
316390
)
317391

318392
self._verify_allowed_token_ids()
393+
self._verify_grammar_constraint()
319394

320395
return
321396

397+
def _verify_grammar_constraint(self):
398+
if self.guided_grammar.length != 0:
399+
if self.regular_constraint.length != 0:
400+
raise ValueError("guided_grammar and regular_constraint can not be used in same time")
401+
if self.guided_json.length != 0:
402+
raise ValueError("guided_grammar and guided_json can not be used in same time")
403+
return
404+
322405
def _verify_allowed_token_ids(self):
323406
if self.allowed_token_ids.size != 0:
324407
if self.regular_constraint.length != 0:
325408
raise ValueError("allowed_token_ids and regular_constraint can not be used in same time")
409+
if self.guided_grammar.length != 0:
410+
raise ValueError("allowed_token_ids and guided_grammar can not be used in same time")
411+
if self.guided_json.length != 0:
412+
raise ValueError("allowed_token_ids and guided_json can not be used in same time")
326413
return
327414

328415
def to_dict(self):
@@ -342,6 +429,8 @@ def to_dict(self):
342429
"best_of": self.best_of,
343430
"input_penalty": self.input_penalty,
344431
"regular_constraint": self.regular_constraint.to_str(),
432+
"guided_grammar": self.guided_grammar.to_str(),
433+
"guided_json": self.guided_json.to_str(),
345434
"allowed_token_ids": self.allowed_token_ids.to_list(),
346435
"group_request_id": self.group_request_id,
347436
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class StartArgs:
4141
enable_chunked_prefill: bool = field(default=False)
4242
diverse_mode: bool = field(default=False)
4343
token_healing_mode: bool = field(default=False)
44-
simple_constraint_mode: bool = field(default=False)
44+
output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]})
4545
first_token_constraint_mode: bool = field(default=False)
4646
enable_multimodal: bool = field(default=False)
4747
cache_capacity: int = field(default=200)

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import collections
88

99
from dataclasses import dataclass, field
10-
from typing import List, Dict, Tuple, Optional, Any
10+
from typing import List, Dict, Tuple, Optional, Union, Any
1111
from lightllm.common.req_manager import ReqManager
1212
from lightllm.utils.infer_utils import mark_start, mark_end
1313
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
@@ -194,10 +194,15 @@ def __init__(
194194

195195
# output constraint states
196196
self.regular_constraint = self.shm_param.regular_constraint.to_str()
197+
self.guided_grammar = self.shm_param.guided_grammar.to_str()
198+
self.guided_json = self.shm_param.guided_json.to_str()
197199
if len(self.regular_constraint) == 0:
198200
self.regular_constraint = None
201+
if len(self.guided_grammar) == 0:
202+
self.guided_grammar = None
203+
if len(self.guided_json) == 0:
204+
self.guided_json = None
199205

200-
self.regex_guide = None
201206
self.fsm_current_state: int = 0
202207
self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list()
203208
if len(self.allowed_token_ids) == 0:
@@ -217,7 +222,12 @@ def __init__(
217222
return
218223

219224
def has_constraint_setting(self) -> bool:
220-
return self.regular_constraint is not None or self.allowed_token_ids is not None
225+
return (
226+
self.regular_constraint is not None
227+
or self.allowed_token_ids is not None
228+
or self.guided_grammar is not None
229+
or self.guided_json is not None
230+
)
221231

222232

223233
class InferReq:

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from .chunked_prefill.impl import ChunkedPrefillBackend
55
from .diverse_backend.impl import DiversehBackend
66
from .continues_batch.impl_for_token_healing import TokenHealingBackend
7-
from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
7+
from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
88
from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
99
from .dp_backend.impl import DPBackend
1010
from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
1111
from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
12+
from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_simple_constraint_mode.py renamed to lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_outlines_constraint_mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
logger = init_logger(__name__)
1515

1616

17-
class SimpleConstraintBackend(ContinuesBatchBackend):
17+
class OutlinesConstraintBackend(ContinuesBatchBackend):
1818
def __init__(self) -> None:
1919
super().__init__()
2020

0 commit comments

Comments
 (0)