Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions examples/offline_inference/structured_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

MAX_TOKENS = 50

# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
Expand All @@ -23,7 +25,9 @@
# Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"]
guided_decoding=guided_decoding_params_regex,
stop=["\n"],
max_tokens=MAX_TOKENS,
)
prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma."
Expand All @@ -48,7 +52,10 @@ class CarDescription(BaseModel):

json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
sampling_params_json = SamplingParams(
guided_decoding=guided_decoding_params_json,
max_tokens=MAX_TOKENS,
)
prompt_json = (
"Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
Expand All @@ -64,7 +71,10 @@ class CarDescription(BaseModel):
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
sampling_params_grammar = SamplingParams(
guided_decoding=guided_decoding_params_grammar,
max_tokens=MAX_TOKENS,
)
prompt_grammar = (
"Generate an SQL query to show the 'username' and 'email'from the 'users' table."
)
Expand Down