@@ -12,6 +12,8 @@
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
+GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
+JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
 
 
 class StopSequence(ctypes.Structure):
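The two new buffer sizes follow the file's existing pattern: fixed-size ctypes buffers whose capacity is read from the environment once, at import time. A minimal sketch of overriding one of them (the value is illustrative); because `os.getenv` runs at module load, the override must be in place before lightllm is imported:

```python
import os

# Must be set before importing lightllm; the constant above is evaluated
# once via os.getenv when the module loads.
os.environ["LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH"] = "4096"  # illustrative value
```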
@@ -76,7 +78,7 @@ def to_list(self):
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
-        ("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH),
+        ("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH),
         ("length", ctypes.c_int),
     ]
 
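The `c_byte` to `c_ubyte` switch is a bug fix rather than a style change: `to_str` slices the array, slicing a `c_byte` array yields signed Python ints, and `bytes()` rejects the negative values that UTF-8 bytes >= 0x80 produce. A standalone sketch of the failure mode (not lightllm code):

```python
import ctypes

data = "é".encode("utf-8")  # b'\xc3\xa9': both bytes are >= 0x80

signed = (ctypes.c_byte * 4)()
ctypes.memmove(signed, data, len(data))
print(signed[0:2])          # [-61, -87]: slicing yields signed ints
# bytes(signed[0:2])        # raises ValueError: bytes must be in range(0, 256)

unsigned = (ctypes.c_ubyte * 4)()
ctypes.memmove(unsigned, data, len(data))
print(bytes(unsigned[0:2]).decode("utf-8"))  # 'é': round-trips cleanly
```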
@@ -98,6 +100,66 @@ def to_str(self):
         return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
 
 
+class GuidedGrammar(ctypes.Structure):
+    _pack_ = 4
+    _fields_ = [
+        ("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH),
+        ("length", ctypes.c_int),
+    ]
+
+    def initialize(self, constraint: str, tokenizer):
+        constraint_bytes = constraint.encode("utf-8")
+        assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long."
+
+        ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
+        self.length = len(constraint_bytes)
+        try:
+            if self.length > 0:
+                import xgrammar as xgr
+
+                tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+                xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+                xgrammar_compiler.compile_grammar(constraint)
+        except Exception as e:
+            raise ValueError(f"guided_grammar '{constraint}' failed to compile: {str(e)}")
+        return
+
+    def to_str(self):
+        if self.length == 0:
+            return ""
+        return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
+
+
+class GuidedJsonSchema(ctypes.Structure):
+    _pack_ = 4
+    _fields_ = [
+        ("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH),
+        ("length", ctypes.c_int),
+    ]
+
+    def initialize(self, constraint: str, tokenizer):
+        constraint_bytes = constraint.encode("utf-8")
+        assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long."
+
+        ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
+        self.length = len(constraint_bytes)
+        try:
+            if self.length > 0:
+                import xgrammar as xgr
+
+                tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
+                xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
+                xgrammar_compiler.compile_json_schema(constraint)
+        except Exception as e:
+            raise ValueError(f"guided_json '{constraint}' failed to compile: {str(e)}")
+        return
+
+    def to_str(self):
+        if self.length == 0:
+            return ""
+        return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")
+
+
 class AllowedTokenIds(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
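Both new classes compile the constraint eagerly, at request-construction time, so a malformed grammar or schema fails fast with a ValueError instead of surfacing mid-decode; the compiled result is discarded, with compilation serving purely as validation. A sketch of the same check in isolation, using the xgrammar calls from the diff (the model name is illustrative only):

```python
from transformers import AutoTokenizer
import xgrammar as xgr

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative model

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

# An EBNF grammar: compile_grammar raises on malformed input, which
# GuidedGrammar.initialize converts into a ValueError for the client.
compiler.compile_grammar('root ::= "yes" | "no"')

# The JSON-schema variant goes through compile_json_schema instead.
compiler.compile_json_schema('{"type": "object", "properties": {"a": {"type": "integer"}}}')
```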
@@ -191,9 +253,11 @@ class SamplingParams(ctypes.Structure):
         # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
         ("input_penalty", ctypes.c_bool),
         ("regular_constraint", RegularConstraint),
+        ("guided_grammar", GuidedGrammar),
+        ("guided_json", GuidedJsonSchema),
         # If provided, the engine will construct a logits
         # processor which only retains scores for the given token ids. Defaults to None.
-        # allowed_token_ids can only be used on a server started with "--simple_constraint_mode".
+        # allowed_token_ids can only be used on a server started with "--output_constraint_mode outlines".
         ("allowed_token_ids", AllowedTokenIds),
         ("stop_sequences", StopSequenceGroups),
         ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
@@ -251,6 +315,16 @@ def init(self, tokenizer, **kwargs):
         self.regular_constraint = RegularConstraint()
         self.regular_constraint.initialize(regular_constraint)
 
+        # Initialize guided_grammar
+        guided_grammar = kwargs.get("guided_grammar", "")
+        self.guided_grammar = GuidedGrammar()
+        self.guided_grammar.initialize(guided_grammar, tokenizer)
+
+        # Initialize guided_json
+        guided_json = kwargs.get("guided_json", "")
+        self.guided_json = GuidedJsonSchema()
+        self.guided_json.initialize(guided_json, tokenizer)
+
         # Initialize stop_sequence_groups
         stop_sequences = kwargs.get("stop_sequences", [])
         self.stop_sequences = StopSequenceGroups()
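With the fields wired into `init`, callers pass the constraints as plain strings through kwargs. A hedged usage sketch, reusing the tokenizer from the previous snippet and assuming the remaining kwargs can take their defaults:

```python
params = SamplingParams()
params.init(
    tokenizer,                               # a HuggingFace tokenizer instance
    guided_grammar='root ::= "yes" | "no"',  # validated via xgrammar inside initialize()
)
print(params.to_dict()["guided_grammar"])    # root ::= "yes" | "no"
```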
@@ -316,13 +390,26 @@ def verify(self):
             )
 
         self._verify_allowed_token_ids()
+        self._verify_grammar_constraint()
 
         return
 
+    def _verify_grammar_constraint(self):
+        if self.guided_grammar.length != 0:
+            if self.regular_constraint.length != 0:
+                raise ValueError("guided_grammar and regular_constraint cannot be used at the same time")
+            if self.guided_json.length != 0:
+                raise ValueError("guided_grammar and guided_json cannot be used at the same time")
+        return
+
     def _verify_allowed_token_ids(self):
         if self.allowed_token_ids.size != 0:
             if self.regular_constraint.length != 0:
                 raise ValueError("allowed_token_ids and regular_constraint cannot be used at the same time")
+            if self.guided_grammar.length != 0:
+                raise ValueError("allowed_token_ids and guided_grammar cannot be used at the same time")
+            if self.guided_json.length != 0:
+                raise ValueError("allowed_token_ids and guided_json cannot be used at the same time")
         return
 
     def to_dict(self):
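`verify()` now enforces that at most one constrained-output mechanism is active per request. A minimal sketch of the exclusion, under the same assumptions as the snippet above:

```python
params = SamplingParams()
params.init(
    tokenizer,
    guided_grammar='root ::= "yes" | "no"',
    guided_json='{"type": "string"}',
)
params.verify()  # ValueError: guided_grammar and guided_json cannot be used at the same time
```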
@@ -342,6 +429,8 @@ def to_dict(self):
             "best_of": self.best_of,
             "input_penalty": self.input_penalty,
             "regular_constraint": self.regular_constraint.to_str(),
+            "guided_grammar": self.guided_grammar.to_str(),
+            "guided_json": self.guided_json.to_str(),
             "allowed_token_ids": self.allowed_token_ids.to_list(),
             "group_request_id": self.group_request_id,
             "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),