@@ -93,13 +93,17 @@ async def get_gen_prompt(request, args):
     return prompt_adapter.generate_prompt(request.messages), request
 
 
-async def check_length(request, prompt, args):
-    if "baichuan-13b" in args.model_name.lower():
-        input_ids = build_baichuan_chat_input(tokenizer, prompt)
-    elif "qwen" in args.model_name.lower():
-        input_ids = build_qwen_chat_input(tokenizer, prompt)
-    else:
+async def get_model_inputs(request, prompt, args):
+    if isinstance(prompt, str):
         input_ids = tokenizer(prompt).input_ids
+    else:
+        if "baichuan-13b" in args.model_name.lower():
+            input_ids = build_baichuan_chat_input(tokenizer, prompt)
+        elif "qwen" in args.model_name.lower():
+            input_ids = build_qwen_chat_input(tokenizer, prompt)
+        else:
+            raise ValueError(f"Model not supported yet: {args.model_name.lower()}")
+
     token_num = len(input_ids)
     if token_num + request.max_tokens > max_model_len:
         return input_ids, create_error_response(
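The rename reflects the helper's broader job: it now tokenizes both plain-string prompts and chat-message lists, with model-specific template builders for the latter. A minimal sketch of the two call shapes, assuming the server module's globals (`tokenizer`, `max_model_len`) are in scope; the `SimpleNamespace` objects are illustrative stand-ins, not the repo's real request/args types:

```python
import asyncio
from types import SimpleNamespace

async def demo():
    # Stand-ins for the real request/args objects, for illustration only.
    args = SimpleNamespace(model_name="Qwen-7B-Chat")
    request = SimpleNamespace(max_tokens=64)

    # /v1/completions path: the prompt is a plain string, tokenized directly.
    token_ids, err = await get_model_inputs(request, "Hello, world", args)

    # /v1/chat/completions path: the prompt is a message list, routed to the
    # model-specific template builder (build_qwen_chat_input for Qwen).
    messages = [{"role": "user", "content": "Hello!"}]
    token_ids, err = await get_model_inputs(request, messages, args)

asyncio.run(demo())
```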
@@ -143,7 +147,7 @@ async def create_chat_completion(raw_request: Request):
 
     prompt, request = await get_gen_prompt(request, args)
     request.max_tokens = request.max_tokens or 512
-    token_ids, error_check_ret = await check_length(request, prompt, args)
+    token_ids, error_check_ret = await get_model_inputs(request, prompt, args)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -169,6 +173,10 @@ async def create_chat_completion(raw_request: Request):
             top_p=request.top_p,
             stop=list(set(stop)),
             max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
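The four added keyword arguments forward request fields straight into vLLM's `SamplingParams`. For reference, a standalone construction with the full set this hunk passes; values are placeholders, and `use_beam_search` assumes a vLLM release that still supports beam search:

```python
# Sketch of the SamplingParams built in this hunk; values are placeholders.
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    top_k=40,               # -1 disables top-k filtering in vLLM
    best_of=2,              # number of candidate sequences to rank
    stop=["</s>"],
    ignore_eos=False,       # True keeps generating past the EOS token
    use_beam_search=False,  # older vLLM requires best_of > 1 and temperature == 0 when True
    max_tokens=512,
)
```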
@@ -377,7 +385,7 @@ async def create_completion(raw_request: Request):
     else:
         prompt = request.prompt
 
-    token_ids, error_check_ret = await check_length(request, prompt, args)
+    token_ids, error_check_ret = await get_model_inputs(request, prompt, args)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -388,10 +396,12 @@ async def create_completion(raw_request: Request):
             presence_penalty=request.presence_penalty,
             frequency_penalty=request.frequency_penalty,
             temperature=request.temperature,
-            top_p=request.top_p,
+            top_k=request.top_k,
             stop=request.stop,
+            ignore_eos=request.ignore_eos,
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
+            use_beam_search=request.use_beam_search,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
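With both call sites updated, clients can pass these vLLM-specific knobs through the OpenAI-style endpoints. A hypothetical request against the chat endpoint; the URL, port, and model name are placeholders, and the extra fields assume the server's request schema declares them:

```python
# Hypothetical client call exercising the newly exposed parameters.
import requests

payload = {
    "model": "qwen-7b-chat",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 64,
    "top_k": 40,
    "best_of": 2,
    "ignore_eos": False,
    "use_beam_search": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
```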