@@ -11,7 +11,7 @@
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid
 
 
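For context: in vllm.sampling_params, LogitsProcessor is a type alias for a callable that takes the token ids generated so far plus the logits tensor for the next token, and returns a (possibly modified) logits tensor. A minimal sketch of a conforming processor follows; the ban-list use case is purely illustrative, not part of this change:

from typing import List

import torch

def ban_token(banned_id: int):
    """Build a callable matching vLLM's LogitsProcessor signature:
    (generated_token_ids, next_token_logits) -> next_token_logits."""
    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        logits[banned_id] = float("-inf")  # make this token unsampleable
        return logits
    return processor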
@@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self,
-                           tokenizer: PreTrainedTokenizer) -> SamplingParams:
-        # We now allow logprobs being true without top_logrobs.
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
 
+        # We now allow logprobs being true without top_logprobs.
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
             allowed_token_ids=None,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -241,7 +248,7 @@ def to_sampling_params(self,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
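With the new signature, the caller builds the guided-decoding processor and the length default before converting the request. A hedged sketch of such a call site; build_guided_processor, max_model_len, and request_prompt are illustrative names, not the actual serving-layer code:

# Illustrative only: how a serving layer might invoke the new signature.
prompt_ids = tokenizer(request_prompt).input_ids
default_max_tokens = max_model_len - len(prompt_ids)  # remaining context budget
guided_processor = build_guided_processor(request, tokenizer)  # may be None
sampling_params = request.to_sampling_params(
    tokenizer, guided_processor, default_max_tokens)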
@@ -395,14 +402,23 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0
 
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
             allowed_token_ids=self.allowed_token_ids,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -419,7 +435,7 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
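Note that the CompletionRequest path keeps one subtlety: echo=True with max_tokens=0 still yields max_tokens=1, because the engine needs at least one decode step to return prompt logprobs. A sketch of the resulting behavior, assuming a HF tokenizer in scope and illustrative request fields:

# Sketch only: echo-without-generation forces a single step.
req = CompletionRequest(model="m", prompt="hello", echo=True, max_tokens=0)
params = req.to_sampling_params(tokenizer, None, default_max_tokens=4096)
assert params.max_tokens == 1  # one step, so prompt logprobs can be returned

Since max_tokens is 0 here rather than omitted, the default_max_tokens fallback does not fire; it only applies when the request leaves max_tokens unset (None).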