Skip to content

Commit ebf3bf8

Browse files
authored
BUG: vllm structured output compatibility (#4111)
1 parent 5e5e938 commit ebf3bf8

File tree

1 file changed

+69
-15
lines changed

1 file changed

+69
-15
lines changed

xinference/model/llm/vllm/core.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -936,9 +936,21 @@ def _convert_request_output_to_completion(
936936

937937
async def _get_tokenizer(self, lora_request: Any) -> Any:
    """Return the engine's tokenizer, adapting to the vLLM API version.

    vLLM 0.11.0 removed the ``lora_request`` parameter from
    ``AsyncLLMEngine.get_tokenizer``; older releases expose
    ``get_tokenizer(lora_request)`` or, on some versions,
    ``get_tokenizer_async(lora_request)``.

    :param lora_request: LoRA request forwarded to pre-0.11.0 vLLM engines;
        ignored on vLLM 0.11.0+.
    :return: the tokenizer object provided by the engine.
    """
    try:
        # BUG FIX: comparing ``VLLM_VERSION.base_version >= "0.11.0"`` is a
        # lexicographic *string* comparison, so e.g. "0.9.0" >= "0.11.0" is
        # True ('9' > '1') and pre-0.11 engines were wrongly sent down the
        # no-argument path. Parse the base version before comparing; using
        # base_version also treats pre-releases such as 0.11.0rc1 as 0.11.0,
        # which is what the original dual-condition check intended.
        if version.parse(VLLM_VERSION.base_version) >= version.parse("0.11.0"):
            # vLLM 0.11.0+ get_tokenizer doesn't accept lora_request parameter
            return await self._engine.get_tokenizer()  # type: ignore
        else:
            return await self._engine.get_tokenizer(lora_request)  # type: ignore
    except AttributeError:
        # Fallback to get_tokenizer_async for older versions
        try:
            return await self._engine.get_tokenizer_async(lora_request)  # type: ignore
        except (AttributeError, TypeError):
            # If all else fails, try without parameters
            return await self._engine.get_tokenizer()  # type: ignore
942954

943955
def _tokenize(self, tokenizer: Any, prompt: str, config: dict) -> List[int]:
944956
truncate_prompt_tokens = config.get("truncate_prompt_tokens")
@@ -1019,23 +1031,65 @@ async def async_generate(
10191031
# guided decoding only available for vllm >= 0.6.3
10201032
from vllm.sampling_params import GuidedDecodingParams
10211033

1022-
guided_options = GuidedDecodingParams.from_optional(
1023-
json=sanitized_generate_config.pop("guided_json", None),
1024-
regex=sanitized_generate_config.pop("guided_regex", None),
1025-
choice=sanitized_generate_config.pop("guided_choice", None),
1026-
grammar=sanitized_generate_config.pop("guided_grammar", None),
1027-
json_object=sanitized_generate_config.pop("guided_json_object", None),
1028-
backend=sanitized_generate_config.pop("guided_decoding_backend", None),
1029-
whitespace_pattern=sanitized_generate_config.pop(
1030-
"guided_whitespace_pattern", None
1031-
),
1034+
# Extract guided decoding parameters
1035+
guided_params: dict[str, Any] = {}
1036+
guided_json = sanitized_generate_config.pop("guided_json", None)
1037+
if guided_json:
1038+
guided_params["json"] = guided_json
1039+
1040+
guided_regex = sanitized_generate_config.pop("guided_regex", None)
1041+
if guided_regex:
1042+
guided_params["regex"] = guided_regex
1043+
1044+
guided_choice = sanitized_generate_config.pop("guided_choice", None)
1045+
if guided_choice:
1046+
guided_params["choice"] = guided_choice
1047+
1048+
guided_grammar = sanitized_generate_config.pop("guided_grammar", None)
1049+
if guided_grammar:
1050+
guided_params["grammar"] = guided_grammar
1051+
1052+
guided_json_object = sanitized_generate_config.pop(
1053+
"guided_json_object", None
1054+
)
1055+
if guided_json_object:
1056+
guided_params["json_object"] = guided_json_object
1057+
1058+
guided_backend = sanitized_generate_config.pop(
1059+
"guided_decoding_backend", None
10321060
)
1061+
if guided_backend:
1062+
guided_params["_backend"] = guided_backend
10331063

1034-
sampling_params = SamplingParams(
1035-
guided_decoding=guided_options, **sanitized_generate_config
1064+
guided_whitespace_pattern = sanitized_generate_config.pop(
1065+
"guided_whitespace_pattern", None
10361066
)
1067+
if guided_whitespace_pattern:
1068+
guided_params["whitespace_pattern"] = guided_whitespace_pattern
1069+
1070+
# Create GuidedDecodingParams if we have any guided parameters
1071+
guided_options = None
1072+
if guided_params:
1073+
try:
1074+
guided_options = GuidedDecodingParams(**guided_params)
1075+
except Exception as e:
1076+
logger.warning(f"Failed to create GuidedDecodingParams: {e}")
1077+
guided_options = None
1078+
1079+
# Use structured_outputs for vLLM >= 0.11.0, guided_decoding for older versions
1080+
if (
1081+
VLLM_VERSION >= version.parse("0.11.0")
1082+
or VLLM_VERSION.base_version >= "0.11.0"
1083+
):
1084+
sampling_params = SamplingParams(
1085+
structured_outputs=guided_options, **sanitized_generate_config
1086+
)
1087+
else:
1088+
sampling_params = SamplingParams(
1089+
guided_decoding=guided_options, **sanitized_generate_config
1090+
)
10371091
else:
1038-
# ignore generate configs
1092+
# ignore generate configs for older versions
10391093
sanitized_generate_config.pop("guided_json", None)
10401094
sanitized_generate_config.pop("guided_regex", None)
10411095
sanitized_generate_config.pop("guided_choice", None)

0 commit comments

Comments
 (0)