From a5e6abf6c305ba0aca11a7fd77247a64c68359df Mon Sep 17 00:00:00 2001
From: leiwen83
Date: Thu, 7 Sep 2023 10:50:39 +0800
Subject: [PATCH] add best_of and use_beam_search for completions interface
 (#2372)

Signed-off-by: Lei Wen
Co-authored-by: Lei Wen
---
 fastchat/protocol/api_protocol.py        |  2 +-
 fastchat/protocol/openai_api_protocol.py |  4 +-
 fastchat/serve/openai_api_server.py      | 29 +++++++++-
 fastchat/serve/vllm_worker.py            | 70 +++++++++++++++++-------
 4 files changed, 79 insertions(+), 26 deletions(-)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index 7dc8fe1c3..1091f5e5a 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: UsageInfo
+    usage: Union[UsageInfo, List[UsageInfo]]
 
 
 class CompletionResponseStreamChoice(BaseModel):
diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 6232e8b9b..fc3c91ebd 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -151,11 +151,13 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
+    use_beam_search: Optional[bool] = False
+    best_of: Optional[int] = None
 
 
 class CompletionResponseChoice(BaseModel):
     index: int
-    text: str
+    text: Union[str, List[str]]
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index 02e8481f4..e399345d8 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -241,6 +241,9 @@ async def get_gen_params(
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
+    best_of: Optional[int] = None,
+    n: Optional[int] = 1,
+    use_beam_search: Optional[bool] = None,
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(
@@ -287,6 +290,11 @@
         "stop_token_ids": conv.stop_token_ids,
     }
 
+    if best_of is not None:
+        gen_params.update({"n": n, "best_of": best_of})
+    if use_beam_search is not None:
+        gen_params.update({"use_beam_search": use_beam_search})
+
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)
@@ -494,12 +502,18 @@ async def create_completion(request: CompletionRequest):
                 max_tokens=request.max_tokens,
                 echo=request.echo,
                 stop=request.stop,
+                best_of=request.best_of,
+                n=request.n,
+                use_beam_search=request.use_beam_search,
             )
             for i in range(request.n):
                 content = asyncio.create_task(
                     generate_completion(gen_params, worker_addr)
                 )
                 text_completions.append(content)
+                # when use with best_of, only need send one request
+                if request.best_of:
+                    break
 
         try:
             all_tasks = await asyncio.gather(*text_completions)
@@ -519,9 +533,18 @@
                     finish_reason=content.get("finish_reason", "stop"),
                 )
             )
-            task_usage = UsageInfo.parse_obj(content["usage"])
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+            idx = 0
+            while True:
+                info = content["usage"]
+                if isinstance(info, list):
+                    info = info[idx]
+
+                task_usage = UsageInfo.parse_obj(info)
+
+                for usage_key, usage_value in task_usage.dict().items():
+                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+                idx += 1
+                break
 
         return CompletionResponse(
             model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 8e255b79c..71a30f890 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -18,6 +18,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
+from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,
@@ -74,6 +75,9 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
+        use_beam_search = params.get("use_beam_search", False)
+        best_of = params.get("best_of", None)
+        n = params.get("n", 1)
 
         # Handle stop_str
         stop = set()
@@ -90,27 +94,51 @@
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        sampling_params = SamplingParams(
-            n=1,
-            temperature=temperature,
-            top_p=top_p,
-            use_beam_search=False,
-            stop=list(stop),
-            max_tokens=max_new_tokens,
-        )
-        results_generator = engine.generate(context, sampling_params, request_id)
-
-        async for request_output in results_generator:
-            prompt = request_output.prompt
-            if echo:
-                text_outputs = [
-                    prompt + output.text for output in request_output.outputs
-                ]
-            else:
-                text_outputs = [output.text for output in request_output.outputs]
-            text_outputs = " ".join(text_outputs)
-            # Note: usage is not supported yet
-            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
+        try:
+            sampling_params = SamplingParams(
+                n=n,
+                temperature=temperature,
+                top_p=top_p,
+                use_beam_search=use_beam_search,
+                stop=list(stop),
+                max_tokens=max_new_tokens,
+                best_of=best_of,
+            )
+
+            results_generator = engine.generate(context, sampling_params, request_id)
+
+            async for request_output in results_generator:
+                prompt = request_output.prompt
+                prompt_tokens = len(request_output.prompt_token_ids)
+                output_usage = []
+                for out in request_output.outputs:
+                    completion_tokens = len(out.token_ids)
+                    total_tokens = prompt_tokens + completion_tokens
+                    output_usage.append(
+                        {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                        }
+                    )
+
+                if echo:
+                    text_outputs = [
+                        prompt + output.text for output in request_output.outputs
+                    ]
+                else:
+                    text_outputs = [output.text for output in request_output.outputs]
+
+                if sampling_params.best_of is None:
+                    text_outputs = [" ".join(text_outputs)]
+                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
+                yield (json.dumps(ret) + "\0").encode()
+        except (ValueError, RuntimeError) as e:
+            ret = {
+                "text": f"{e}",
+                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
+                "usage": {},
+            }
             yield (json.dumps(ret) + "\0").encode()
 
     async def generate(self, params):
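
Below is an illustrative client-side sketch (not part of the patch) of how the new
best_of / use_beam_search request fields and the now possibly list-valued text / usage
response fields might be exercised against a running FastChat OpenAI-compatible server.
The base URL, model name, prompt, and the temperature/best_of values chosen for beam
search are assumptions for the example, not values taken from this diff.

# Illustrative only: the endpoint URL and model name below are assumptions.
import requests

BASE_URL = "http://localhost:8000/v1"  # assumed address of openai_api_server

payload = {
    "model": "vicuna-7b-v1.5",   # placeholder model name
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "n": 2,
    "best_of": 4,                # new request field added by this patch
    "use_beam_search": True,     # new request field added by this patch
    "temperature": 0.0,          # beam search in vLLM generally expects temperature 0
}

resp = requests.post(f"{BASE_URL}/completions", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()

for choice in body["choices"]:
    texts = choice["text"]
    # With best_of, a choice's "text" may now be a list of candidates rather than a str.
    if isinstance(texts, str):
        texts = [texts]
    for candidate in texts:
        print(repr(candidate))

# "usage" may likewise be a single object or a list of per-candidate objects.
usage = body["usage"]
print(usage if isinstance(usage, dict) else usage[0])

Note that with best_of set, create_completion sends only one request to the worker
(per the added break above), so the n candidates come back inside a single choice
rather than as separate choices.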