diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index a9c9fbed0cbaa..4143e1af8ae04 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -160,16 +160,26 @@ async def show_available_models():
     return ModelList(data=model_cards)
 
 
-def create_logprobs(token_ids: List[int],
-                    id_logprobs: List[Dict[int, float]],
-                    initial_text_offset: int = 0) -> LogProbs:
+def create_logprobs(
+    token_ids: List[int],
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+    num_output_top_logprobs: Optional[int] = None,
+    initial_text_offset: int = 0,
+) -> LogProbs:
     """Create OpenAI-style logprobs."""
     logprobs = LogProbs()
     last_token_len = 0
-    for token_id, id_logprob in zip(token_ids, id_logprobs):
+    if num_output_top_logprobs:
+        logprobs.top_logprobs = []
+    for i, token_id in enumerate(token_ids):
+        step_top_logprobs = top_logprobs[i]
+        if step_top_logprobs is not None:
+            token_logprob = step_top_logprobs[token_id]
+        else:
+            token_logprob = None
         token = tokenizer.convert_ids_to_tokens(token_id)
         logprobs.tokens.append(token)
-        logprobs.token_logprobs.append(id_logprob[token_id])
+        logprobs.token_logprobs.append(token_logprob)
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
@@ -177,10 +187,11 @@ def create_logprobs(token_ids: List[int],
                                         last_token_len)
         last_token_len = len(token)
 
-        logprobs.top_logprobs.append({
-            tokenizer.convert_ids_to_tokens(i): p
-            for i, p in id_logprob.items()
-        })
+        if num_output_top_logprobs:
+            logprobs.top_logprobs.append({
+                tokenizer.convert_ids_to_tokens(i): p
+                for i, p in step_top_logprobs.items()
+            } if step_top_logprobs else None)
     return logprobs
 
 
@@ -371,8 +382,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         for the API specification. This API mimics the OpenAI Completion API.
 
     NOTE: Currently we do not support the following features:
-        - echo (since the vLLM engine does not currently support
-          getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
         - logit_bias (to be supported by vLLM engine)
@@ -383,11 +392,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if error_check_ret is not None:
         return error_check_ret
 
-    if request.echo:
-        # We do not support echo since the vLLM engine does not
-        # currently support getting the logprobs of prompt tokens.
-        return create_error_response(HTTPStatus.BAD_REQUEST,
-                                     "echo is not currently supported")
+    # OpenAI API supports echoing the prompt when max_tokens is 0.
+    echo_without_generation = request.echo and request.max_tokens == 0
 
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
@@ -443,9 +449,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             stop=request.stop,
             stop_token_ids=request.stop_token_ids,
             ignore_eos=request.ignore_eos,
-            max_tokens=request.max_tokens,
+            max_tokens=request.max_tokens
+            if not echo_without_generation else 1,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            prompt_logprobs=request.logprobs if request.echo else None,
             skip_special_tokens=request.skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
         )
@@ -495,24 +503,42 @@ def create_stream_response_json(
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        has_echoed = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
+                token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+                offsets = len(previous_texts[i])
+                if request.echo and not has_echoed[i]:
+                    if not echo_without_generation:
+                        delta_text = res.prompt + delta_text
+                        token_ids = res.prompt_token_ids + token_ids
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                    else:
+                        delta_text = res.prompt
+                        token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                    has_echoed[i] = True
                 if request.logprobs is not None:
                     logprobs = create_logprobs(
-                        output.token_ids[previous_num_tokens[i]:],
-                        output.logprobs[previous_num_tokens[i]:],
-                        len(previous_texts[i]))
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=offsets,
+                    )
                 else:
                     logprobs = None
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
+                finish_reason = output.finish_reason
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                     logprobs=logprobs,
+                    finish_reason=finish_reason,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
@@ -551,14 +577,36 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         final_res = res
     assert final_res is not None
     choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
    for output in final_res.outputs:
         if request.logprobs is not None:
-            logprobs = create_logprobs(output.token_ids, output.logprobs)
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
         else:
             logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
         choice_data = CompletionResponseChoice(
             index=output.index,
-            text=output.text,
+            text=output_text,
             logprobs=logprobs,
             finish_reason=output.finish_reason,
         )
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 39db35620307f..797f0a7115e6e 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -106,8 +106,7 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
 
 
 class CompletionResponseChoice(BaseModel):
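
For context, a minimal client-side sketch (not part of the patch) of what this change enables: echo=True prepends the prompt to the returned text, echo=True with max_tokens=0 returns only the echoed prompt, and setting logprobs additionally returns per-prompt-token logprobs. The server address and model name below are placeholders assumed for illustration.

import requests  # assumes a vLLM OpenAI-compatible server is already running

resp = requests.post(
    "http://localhost:8000/v1/completions",   # placeholder server address
    json={
        "model": "facebook/opt-125m",          # placeholder model name
        "prompt": "The capital of France is",
        "echo": True,        # echo the prompt in the response text
        "max_tokens": 0,     # echo-only: generate no new tokens
        "logprobs": 1,       # also return per-token logprobs
    },
)
choice = resp.json()["choices"][0]
# The echoed prompt tokens and their logprobs; the first prompt token has no
# logprob, so the first entry of token_logprobs is None.
print(choice["text"])
print(choice["logprobs"]["tokens"])
print(choice["logprobs"]["token_logprobs"])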