
Commit bacc7b6

xusenlin committed
fix vllm server for additional parameters
1 parent 5126bd2 · commit bacc7b6

File tree

3 files changed: +41 -9 lines changed

api/protocol.py
api/router.py
api/vllm_server.py

api/protocol.py

Lines changed: 12 additions & 0 deletions

@@ -71,6 +71,12 @@ class ChatCompletionRequest(BaseModel):
     functions: Optional[List[Dict[str, Any]]] = None
     function_call: Union[str, Dict[str, str]] = "auto"
 
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
+
 
 class FunctionCallResponse(BaseModel):
     name: str
@@ -103,6 +109,7 @@ class ChatCompletionResponse(BaseModel):
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    function_call: Optional[FunctionCallResponse] = None
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -149,6 +156,11 @@ class CompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
 
+    # Additional parameters supported by vLLM
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
+
 
 class CompletionResponseChoice(BaseModel):
     index: int
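For illustration, a chat request exercising the new fields could look like the sketch below. This is a minimal example, not taken from the repository: it assumes the server is reachable at http://localhost:8000 under an OpenAI-style /v1/chat/completions route, uses a placeholder model name, and relies on the extra body fields being parsed into ChatCompletionRequest as defined above.

import requests

# Hypothetical host, route, and model name; adjust to your deployment.
payload = {
    "model": "qwen-7b-chat",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 256,
    # Fields added in this commit, forwarded to vLLM's sampler:
    "top_k": 50,
    "best_of": 2,
    "ignore_eos": False,
    "use_beam_search": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])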

api/router.py

Lines changed: 10 additions & 0 deletions

@@ -430,6 +430,16 @@ async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
         decoding = tiktoken.model.encoding_for_model(request.model)
         inputs = [decoding.decode(text) for text in inputs]
 
+    # https://huggingface.co/BAAI/bge-large-zh
+    if embed_client is not None:
+        if "bge" in args.embedding_name.lower():
+            instruction = ""
+            if "zh" in args.embedding_name.lower():
+                instruction = "为这个句子生成表示以用于检索相关文章:"
+            elif "en" in args.embedding_name.lower():
+                instruction = "Represent this sentence for searching relevant passages: "
+            inputs = [instruction + q for q in inputs]
+
     data, token_num = [], 0
     batches = [
         inputs[i: min(i + 1024, len(inputs))]
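The linked BGE model cards recommend prepending a retrieval instruction to each query before embedding; the Chinese string above is the bge-*-zh counterpart of the English prefix. Below is a standalone sketch of the same selection logic, using a hypothetical helper name that is not part of the repository.

# Hypothetical helper mirroring the prefix selection in create_embeddings above.
def bge_query_instruction(embedding_name: str) -> str:
    name = embedding_name.lower()
    if "bge" not in name:
        return ""
    if "zh" in name:
        return "为这个句子生成表示以用于检索相关文章:"
    if "en" in name:
        return "Represent this sentence for searching relevant passages: "
    return ""

queries = ["what is vLLM?"]
inputs = [bge_query_instruction("bge-large-en") + q for q in queries]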

api/vllm_server.py

Lines changed: 19 additions & 9 deletions

@@ -93,13 +93,17 @@ async def get_gen_prompt(request, args):
     return prompt_adapter.generate_prompt(request.messages), request
 
 
-async def check_length(request, prompt, args):
-    if "baichuan-13b" in args.model_name.lower():
-        input_ids = build_baichuan_chat_input(tokenizer, prompt)
-    elif "qwen" in args.model_name.lower():
-        input_ids = build_qwen_chat_input(tokenizer, prompt)
-    else:
+async def get_model_inputs(request, prompt, args):
+    if isinstance(prompt, str):
         input_ids = tokenizer(prompt).input_ids
+    else:
+        if "baichuan-13b" in args.model_name.lower():
+            input_ids = build_baichuan_chat_input(tokenizer, prompt)
+        elif "qwen" in args.model_name.lower():
+            input_ids = build_qwen_chat_input(tokenizer, prompt)
+        else:
+            raise ValueError(f"Model not supported yet: {args.model_name.lower()}")
+
     token_num = len(input_ids)
     if token_num + request.max_tokens > max_model_len:
         return input_ids, create_error_response(
@@ -143,7 +147,7 @@ async def create_chat_completion(raw_request: Request):
 
     prompt, request = await get_gen_prompt(request, args)
    request.max_tokens = request.max_tokens or 512
-    token_ids, error_check_ret = await check_length(request, prompt, args)
+    token_ids, error_check_ret = await get_model_inputs(request, prompt, args)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -169,6 +173,10 @@ async def create_chat_completion(raw_request: Request):
             top_p=request.top_p,
             stop=list(set(stop)),
             max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -377,7 +385,7 @@ async def create_completion(raw_request: Request):
     else:
         prompt = request.prompt
 
-    token_ids, error_check_ret = await check_length(request, prompt, args)
+    token_ids, error_check_ret = await get_model_inputs(request, prompt, args)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -388,10 +396,12 @@ async def create_completion(raw_request: Request):
             presence_penalty=request.presence_penalty,
             frequency_penalty=request.frequency_penalty,
             temperature=request.temperature,
-            top_p=request.top_p,
+            top_k=request.top_k,
             stop=request.stop,
+            ignore_eos=request.ignore_eos,
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
+            use_beam_search=request.use_beam_search,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
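The new request fields map directly onto keyword arguments of vLLM's SamplingParams, as the hunks above show. Here is a minimal sketch of that object with the defaults introduced in api/protocol.py and illustrative values for the other sampling fields, assuming an older vLLM release (contemporary with this commit) in which SamplingParams still accepts use_beam_search.

from vllm import SamplingParams

# Defaults mirror the new fields in ChatCompletionRequest / CompletionRequest.
params = SamplingParams(
    temperature=0.7,
    top_p=1.0,
    top_k=-1,               # -1 leaves top-k filtering disabled
    best_of=None,           # None falls back to n
    ignore_eos=False,       # True would keep sampling past the EOS token
    use_beam_search=False,
    max_tokens=512,         # same fallback used by create_chat_completion
)
print(params)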
