9 changes: 8 additions & 1 deletion lightllm/server/api_cli.py
@@ -149,7 +149,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")

parser.add_argument(
"--output_constraint_mode",
type=str,
choices=["outlines", "xgrammar", "none"],
default="none",
help="set the output constraint backend, none means no output constraint",
)
parser.add_argument(
"--first_token_constraint_mode",
action="store_true",
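The new --output_constraint_mode flag folds the old boolean --simple_constraint_mode switch into an explicit backend selector. A minimal sketch of how the flag parses, assuming a lightllm install with this change applied:

# Hedged example: the flag accepts exactly one of "outlines", "xgrammar",
# or "none" (the default, meaning no output constraint).
from lightllm.server.api_cli import make_argument_parser

parser = make_argument_parser()
args = parser.parse_args(["--output_constraint_mode", "xgrammar"])
assert args.output_constraint_mode == "xgrammar"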
8 changes: 7 additions & 1 deletion lightllm/server/core/objs/py_sampling_params.py
@@ -47,9 +47,11 @@ def __init__(
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
input_penalty: bool = DEFAULT_INPUT_PENALTY,
regular_constraint: Optional[str] = None, # Regular expressions constrain the output.
        guided_grammar: Optional[str] = None,  # An EBNF grammar constrains the output.
        guided_json: Optional[Union[str, dict]] = None,  # A JSON schema constrains the output.
        # If provided, the engine will construct a logits
        # processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
# allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
allowed_token_ids: Optional[List[int]] = None,
# p d mode used params
group_request_id: Optional[int] = None,
@@ -81,6 +83,8 @@ def __init__(
self.add_spaces_between_special_tokens = add_spaces_between_special_tokens
self.print_eos_token = print_eos_token
self.regular_constraint = regular_constraint
self.guided_grammar = guided_grammar
self.guided_json = guided_json
self.allowed_token_ids = allowed_token_ids
self.group_request_id = group_request_id
self.move_kv_to_decode_node = move_kv_to_decode_node
@@ -257,6 +261,8 @@ def to_dict(self):
ret["best_of"] = self.best_of
ret["input_penalty"] = self.input_penalty
ret["regular_constraint"] = self.regular_constraint
ret["guided_grammar"] = self.guided_grammar
ret["guided_json"] = self.guided_json
ret["allowed_token_ids"] = self.allowed_token_ids
ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
return ret
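Together these fields let a client request grammar- or schema-constrained decoding per request. A hedged sketch of a client payload; the field names come from this diff, while the /generate request shape follows LightLLM's existing HTTP API and the local URL is a placeholder:

import json

import requests

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
payload = {
    "inputs": "Describe a person as JSON:",
    "parameters": {
        "max_new_tokens": 64,
        # guided_json is mutually exclusive with guided_grammar,
        # regular_constraint, and allowed_token_ids (see verify() below).
        "guided_json": json.dumps(schema),
    },
}
print(requests.post("http://localhost:8000/generate", json=payload).json())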
93 changes: 91 additions & 2 deletions lightllm/server/core/objs/sampling_params.py
@@ -12,6 +12,8 @@
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))


class StopSequence(ctypes.Structure):
@@ -76,7 +78,7 @@ def to_list(self):
class RegularConstraint(ctypes.Structure):
_pack_ = 4
_fields_ = [
("constraint", ctypes.c_byte * REGULAR_CONSTRAINT_MAX_LENGTH),
("constraint", ctypes.c_ubyte * REGULAR_CONSTRAINT_MAX_LENGTH),
("length", ctypes.c_int),
]

@@ -98,6 +100,66 @@ def to_str(self):
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")


class GuidedGrammar(ctypes.Structure):
_pack_ = 4
_fields_ = [
("constraint", ctypes.c_ubyte * GRAMMAR_CONSTRAINT_MAX_LENGTH),
("length", ctypes.c_int),
]

def initialize(self, constraint: str, tokenizer):
constraint_bytes = constraint.encode("utf-8")
assert len(constraint_bytes) < GRAMMAR_CONSTRAINT_MAX_LENGTH, "Guided grammar is too long."

ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
self.length = len(constraint_bytes)
try:
if self.length > 0:
import xgrammar as xgr

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
xgrammar_compiler.compile_grammar(constraint)
except Exception as e:
raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")
return

def to_str(self):
if self.length == 0:
return ""
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")


class GuidedJsonSchema(ctypes.Structure):
_pack_ = 4
_fields_ = [
("constraint", ctypes.c_ubyte * JSON_SCHEMA_MAX_LENGTH),
("length", ctypes.c_int),
]

def initialize(self, constraint: str, tokenizer):
constraint_bytes = constraint.encode("utf-8")
assert len(constraint_bytes) < JSON_SCHEMA_MAX_LENGTH, "Guided json schema is too long."

ctypes.memmove(self.constraint, constraint_bytes, len(constraint_bytes))
self.length = len(constraint_bytes)
try:
if self.length > 0:
import xgrammar as xgr

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
xgrammar_compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
xgrammar_compiler.compile_json_schema(constraint)
except Exception as e:
raise ValueError(f"guided_grammar '{constraint}' has compile_grammar_error: {str(e)}")
return

def to_str(self):
if self.length == 0:
return ""
return bytes(self.constraint[0 : self.length]).decode("utf-8").rstrip("\x00")


class AllowedTokenIds(ctypes.Structure):
_pack_ = 4
_fields_ = [
@@ -191,9 +253,11 @@ class SamplingParams(ctypes.Structure):
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
("input_penalty", ctypes.c_bool),
("regular_constraint", RegularConstraint),
("guided_grammar", GuidedGrammar),
("guided_json", GuidedJsonSchema),
        # If provided, the engine will construct a logits
        # processor which only retains scores for the given token ids. Defaults to None.
# allowed_token_ids only can be used in "--simple_constraint_mode" started server.
        # allowed_token_ids can only be used when the server is started with "--output_constraint_mode outlines".
("allowed_token_ids", AllowedTokenIds),
("stop_sequences", StopSequenceGroups),
("exponential_decay_length_penalty", ExponentialDecayLengthPenalty),
@@ -251,6 +315,16 @@ def init(self, tokenizer, **kwargs):
self.regular_constraint = RegularConstraint()
self.regular_constraint.initialize(regular_constraint)

# Initialize guided_grammar
guided_grammar = kwargs.get("guided_grammar", "")
self.guided_grammar = GuidedGrammar()
self.guided_grammar.initialize(guided_grammar, tokenizer)

# Initialize guided_json
guided_json = kwargs.get("guided_json", "")
self.guided_json = GuidedJsonSchema()
self.guided_json.initialize(guided_json, tokenizer)

# Initialize stop_sequence_groups
stop_sequences = kwargs.get("stop_sequences", [])
self.stop_sequences = StopSequenceGroups()
@@ -316,13 +390,26 @@ def verify(self):
)

self._verify_allowed_token_ids()
self._verify_grammar_constraint()

return

    def _verify_grammar_constraint(self):
        if self.guided_grammar.length != 0:
            if self.regular_constraint.length != 0:
                raise ValueError("guided_grammar and regular_constraint cannot be used at the same time")
            if self.guided_json.length != 0:
                raise ValueError("guided_grammar and guided_json cannot be used at the same time")
        return

    def _verify_allowed_token_ids(self):
        if self.allowed_token_ids.size != 0:
            if self.regular_constraint.length != 0:
                raise ValueError("allowed_token_ids and regular_constraint cannot be used at the same time")
            if self.guided_grammar.length != 0:
                raise ValueError("allowed_token_ids and guided_grammar cannot be used at the same time")
            if self.guided_json.length != 0:
                raise ValueError("allowed_token_ids and guided_json cannot be used at the same time")
        return

def to_dict(self):
@@ -342,6 +429,8 @@ def to_dict(self):
"best_of": self.best_of,
"input_penalty": self.input_penalty,
"regular_constraint": self.regular_constraint.to_str(),
"guided_grammar": self.guided_grammar.to_str(),
"guided_json": self.guided_json.to_str(),
"allowed_token_ids": self.allowed_token_ids.to_list(),
"group_request_id": self.group_request_id,
"move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(),
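Both new structures validate eagerly: initialize() compiles the grammar or schema with xgrammar up front, so a malformed constraint fails at request construction rather than mid-decode. A standalone sketch of that validation step, assuming the xgrammar package is installed and using a placeholder HuggingFace model id:

import xgrammar as xgr
from transformers import AutoTokenizer

# Placeholder model id; any tokenizer with a HuggingFace interface works here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

compiler.compile_grammar('root ::= "yes" | "no"')  # guided_grammar path (EBNF)
compiler.compile_json_schema('{"type": "object"}')  # guided_json path (JSON schema)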
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -41,7 +41,7 @@ class StartArgs:
enable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
simple_constraint_mode: bool = field(default=False)
    output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "outlines", "xgrammar"]})
first_token_constraint_mode: bool = field(default=False)
enable_multimodal: bool = field(default=False)
cache_capacity: int = field(default=200)
16 changes: 13 additions & 3 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -7,7 +7,7 @@
import collections

from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
from typing import List, Dict, Tuple, Optional, Union, Any
from lightllm.common.req_manager import ReqManager
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
@@ -194,10 +194,15 @@ def __init__(

# output constraint states
self.regular_constraint = self.shm_param.regular_constraint.to_str()
self.guided_grammar = self.shm_param.guided_grammar.to_str()
self.guided_json = self.shm_param.guided_json.to_str()
if len(self.regular_constraint) == 0:
self.regular_constraint = None
if len(self.guided_grammar) == 0:
self.guided_grammar = None
if len(self.guided_json) == 0:
self.guided_json = None

self.regex_guide = None
self.fsm_current_state: int = 0
self.allowed_token_ids = self.shm_param.allowed_token_ids.to_list()
if len(self.allowed_token_ids) == 0:
@@ -217,7 +222,12 @@ def __init__(
return

def has_constraint_setting(self) -> bool:
return self.regular_constraint is not None or self.allowed_token_ids is not None
return (
self.regular_constraint is not None
or self.allowed_token_ids is not None
or self.guided_grammar is not None
or self.guided_json is not None
)


class InferReq:
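Empty shared-memory strings are normalized to None above, so downstream code can treat "constraint present" as a simple is-not-None check. An illustrative, hypothetical dispatch over the per-request fields (req stands in for an InferReq-like object carrying the state set above):

def describe_constraint(req) -> str:
    if req.guided_grammar is not None:
        return "xgrammar EBNF grammar"
    if req.guided_json is not None:
        return "xgrammar JSON schema"
    if req.regular_constraint is not None:
        return "outlines regex"
    if req.allowed_token_ids is not None:
        return "token-id whitelist"
    return "unconstrained"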
3 changes: 2 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/__init__.py
@@ -4,8 +4,9 @@
from .chunked_prefill.impl import ChunkedPrefillBackend
from .diverse_backend.impl import DiversehBackend
from .continues_batch.impl_for_token_healing import TokenHealingBackend
from .continues_batch.impl_for_simple_constraint_mode import SimpleConstraintBackend
from .continues_batch.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
from .continues_batch.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend
from .dp_backend.impl import DPBackend
from .continues_batch.pd_mode.prefill_node_impl.prefill_impl import ContinuesBatchBackendForPrefillNode
from .continues_batch.pd_mode.decode_node_impl.decode_impl import ContinuesBatchBackendForDecodeNode
from .continues_batch.impl_for_xgrammar_mode import XgrammarBackend
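With SimpleConstraintBackend renamed to OutlinesConstraintBackend and XgrammarBackend added, the server can map --output_constraint_mode values straight to backend classes. A hypothetical selector, not part of this PR; it assumes ContinuesBatchBackend is exported alongside the classes above:

from lightllm.server.router.model_infer.mode_backend import (
    ContinuesBatchBackend,
    OutlinesConstraintBackend,
    XgrammarBackend,
)

def select_constraint_backend(output_constraint_mode: str):
    if output_constraint_mode == "outlines":
        return OutlinesConstraintBackend()
    if output_constraint_mode == "xgrammar":
        return XgrammarBackend()
    return ContinuesBatchBackend()  # "none": plain continuous batching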
lightllm/server/router/model_infer/mode_backend/continues_batch/{impl_for_simple_constraint_mode.py → impl_for_outlines_constraint_mode.py}
@@ -14,7 +14,7 @@
logger = init_logger(__name__)


class SimpleConstraintBackend(ContinuesBatchBackend):
class OutlinesConstraintBackend(ContinuesBatchBackend):
def __init__(self) -> None:
super().__init__()
