add best_of and use_beam_search for completions interface (#2372)
Signed-off-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
leiwen83 and wenlei03 authored Sep 7, 2023
1 parent dc3dd12 commit a5e6abf
Showing 4 changed files with 79 additions and 26 deletions.
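For orientation, here is a minimal sketch of how a client might exercise the new fields against the OpenAI-compatible completions endpoint. The host, port, and model name are assumptions about a local deployment; only `n`, `best_of`, and `use_beam_search` relate to this commit.

```python
import requests

# Assumed local FastChat API server; adjust the URL and model name to your deployment.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "vicuna-7b-v1.5",   # placeholder model name
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "temperature": 0.0,          # beam search expects greedy-style sampling
        "n": 2,                      # return two candidates
        "best_of": 4,                # new field: generate four candidates, keep the best two
        "use_beam_search": True,     # new field: forwarded to the vLLM worker's SamplingParams
    },
)
print(response.json())
```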
2 changes: 1 addition & 1 deletion fastchat/protocol/api_protocol.py
@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
usage: Union[UsageInfo, List[UsageInfo]]


class CompletionResponseStreamChoice(BaseModel):
4 changes: 3 additions & 1 deletion fastchat/protocol/openai_api_protocol.py
@@ -151,11 +151,13 @@ class CompletionRequest(BaseModel):
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
use_beam_search: Optional[bool] = False
best_of: Optional[int] = None


class CompletionResponseChoice(BaseModel):
index: int
text: str
text: Union[str, List[str]]
logprobs: Optional[int] = None
finish_reason: Optional[Literal["stop", "length"]] = None

29 changes: 26 additions & 3 deletions fastchat/serve/openai_api_server.py
@@ -241,6 +241,9 @@ async def get_gen_params(
max_tokens: Optional[int],
echo: Optional[bool],
stop: Optional[Union[str, List[str]]],
best_of: Optional[int] = None,
n: Optional[int] = 1,
use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
conv = await get_conv(model_name, worker_addr)
conv = Conversation(
@@ -287,6 +290,11 @@
"stop_token_ids": conv.stop_token_ids,
}

if best_of is not None:
gen_params.update({"n": n, "best_of": best_of})
if use_beam_search is not None:
gen_params.update({"use_beam_search": use_beam_search})

new_stop = set()
_add_to_set(stop, new_stop)
_add_to_set(conv.stop_str, new_stop)
@@ -494,12 +502,18 @@ async def create_completion(request: CompletionRequest):
max_tokens=request.max_tokens,
echo=request.echo,
stop=request.stop,
best_of=request.best_of,
n=request.n,
use_beam_search=request.use_beam_search,
)
for i in range(request.n):
content = asyncio.create_task(
generate_completion(gen_params, worker_addr)
)
text_completions.append(content)
# when used with best_of, only one request needs to be sent
if request.best_of:
break

try:
all_tasks = await asyncio.gather(*text_completions)
@@ -519,9 +533,18 @@ async def create_completion(request: CompletionRequest):
finish_reason=content.get("finish_reason", "stop"),
)
)
task_usage = UsageInfo.parse_obj(content["usage"])
for usage_key, usage_value in task_usage.dict().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
idx = 0
while True:
info = content["usage"]
if isinstance(info, list):
info = info[idx]

task_usage = UsageInfo.parse_obj(info)

for usage_key, usage_value in task_usage.dict().items():
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
idx += 1
break

return CompletionResponse(
model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
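Under the protocol changes above, a `best_of` request yields a single choice whose `text` is a list of candidates. A rough sketch of the resulting response body, with values made up and the field set trimmed to what this diff shows:

```python
# Illustrative /v1/completions response when best_of is set (token counts are made up):
example_response = {
    "model": "vicuna-7b-v1.5",  # placeholder model name
    "choices": [
        {
            "index": 0,
            # text is now Union[str, List[str]]: a list of candidates when best_of is used
            "text": [
                "Once upon a time there was a dragon ...",
                "Once upon a time in a quiet village ...",
            ],
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 4, "completion_tokens": 9, "total_tokens": 13},
}
```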
70 changes: 49 additions & 21 deletions fastchat/serve/vllm_worker.py
@@ -18,6 +18,7 @@
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.serve.model_worker import (
BaseModelWorker,
logger,
@@ -74,6 +75,9 @@ async def generate_stream(self, params):
if self.tokenizer.eos_token_id is not None:
stop_token_ids.append(self.tokenizer.eos_token_id)
echo = params.get("echo", True)
use_beam_search = params.get("use_beam_search", False)
best_of = params.get("best_of", None)
n = params.get("n", 1)

# Handle stop_str
stop = set()
@@ -90,27 +94,51 @@
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
stop=list(stop),
max_tokens=max_new_tokens,
)
results_generator = engine.generate(context, sampling_params, request_id)

async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [
prompt + output.text for output in request_output.outputs
]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
# Note: usage is not supported yet
ret = {"text": text_outputs, "error_code": 0, "usage": {}}
try:
sampling_params = SamplingParams(
n=n,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop=list(stop),
max_tokens=max_new_tokens,
best_of=best_of,
)

results_generator = engine.generate(context, sampling_params, request_id)

async for request_output in results_generator:
prompt = request_output.prompt
prompt_tokens = len(request_output.prompt_token_ids)
output_usage = []
for out in request_output.outputs:
completion_tokens = len(out.token_ids)
total_tokens = prompt_tokens + completion_tokens
output_usage.append(
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
)

if echo:
text_outputs = [
prompt + output.text for output in request_output.outputs
]
else:
text_outputs = [output.text for output in request_output.outputs]

if sampling_params.best_of is None:
text_outputs = [" ".join(text_outputs)]
ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
yield (json.dumps(ret) + "\0").encode()
except (ValueError, RuntimeError) as e:
ret = {
"text": f"{e}",
"error_code": ErrorCode.PARAM_OUT_OF_RANGE,
"usage": {},
}
yield (json.dumps(ret) + "\0").encode()

async def generate(self, params):
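On the worker side, each streamed chunk now carries one text entry and one usage entry per candidate, instead of a single joined string with an empty usage dict. A sketch of the shape when `best_of` is set (token counts are illustrative):

```python
# Illustrative chunk yielded by generate_stream after this change (values made up):
chunk = {
    "text": [
        "Once upon a time there was a dragon ...",   # candidate 0
        "Once upon a time in a quiet village ...",   # candidate 1
    ],
    "error_code": 0,
    "usage": [
        {"prompt_tokens": 4, "completion_tokens": 9, "total_tokens": 13},
        {"prompt_tokens": 4, "completion_tokens": 8, "total_tokens": 12},
    ],
}
```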
