
Supporting log probabilities of prompt tokens in both engine and OpenAI API server (aka echo) #959

Closed
wants to merge 40 commits

Conversation

@wanmok (Contributor) commented Sep 6, 2023

According to issue #201, the echo function is not currently supported in vLLM. This implementation does not require modifications to individual model files; instead, it places the heavy lifting in the first generation step.

The basic logic of the implementation:

  • Engine
    1. Added echo to SamplingParams so that the engine can be informed of an echo request.
    2. Modified Sampler to process the echo request (see the sketch after this list).
      a. Flags from the SamplingParams of each SequenceGroup determine whether to perform echo. Log probabilities of prompt tokens are computed at the first generation step.
      b. Modified SequenceOutputs to carry prompt log probs from the first generation step so that the engine can post-process SequenceData and store the log probabilities of prompt tokens. Since the log probabilities of prompt tokens are not necessarily among the top-k returns, they are stored in two separate fields, prompt_logprobs and prompt_top_logprobs.
      c. Changed the numerically unstable computation log(softmax()) to log_softmax().
    3. Modified RequestOutput.from_seq_group so that log probabilities of prompt tokens can be inserted into the final outputs.
  • OpenAI API server
    1. Removed the error message for requesting echo.
    2. Modified the create_logprobs function to reflect the engine output changes.
    3. Edge case: added support for echo=True with max_tokens=0, which is a valid case in the OpenAI API. In this case, SamplingParams.max_tokens is set to 1 and an echo_self flag is enabled. The flag is then used in post-processing to ensure the additionally generated token is removed from the final output.
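
A rough sketch of the idea behind the Sampler change (items 2a-2c above): this is an illustrative, self-contained helper, not the actual code in vllm/model_executor/layers/sampler.py, and the function name and signature are made up for the example.

```python
import torch


def prompt_logprobs_sketch(prompt_logits: torch.Tensor,
                           prompt_token_ids: list,
                           top_k: int = 5):
    """Illustrative only: log probabilities of prompt tokens, computed once
    at the first generation step from the prefill logits.

    prompt_logits: [prompt_len, vocab_size] logits for the prompt positions.
    Returns (prompt_logprobs, prompt_top_logprobs); the first prompt token
    has no preceding context, so its entries are None.
    """
    # log_softmax is numerically stable, unlike log(softmax(...)).
    logprobs = torch.log_softmax(prompt_logits.float(), dim=-1)

    prompt_logprobs = [None]
    prompt_top_logprobs = [None]
    for pos in range(1, len(prompt_token_ids)):
        # Logits at position pos - 1 predict the token at position pos.
        row = logprobs[pos - 1]
        prompt_logprobs.append(row[prompt_token_ids[pos]].item())
        top_values, top_ids = row.topk(top_k)
        prompt_top_logprobs.append(
            {i.item(): v.item() for i, v in zip(top_ids, top_values)})
    return prompt_logprobs, prompt_top_logprobs
```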

Please let me know if you have any questions or concerns with respect to merging this PR. Thanks!

@wanmok changed the title from "A runnable implementation of echo." to "Supporting log probabilities of prompt tokens in both engine and OpenAI API server (aka echo)" Sep 6, 2023
@eugenepentland

Hello, thanks for this! Have you tested it using the OpenAI-compatible endpoint? When I try using curl, it gets the request but I never get a response back. The following is with echo off, but it hangs regardless of whether it is on or off.

INFO 09-07 06:06:49 async_llm_engine.py:226] Received request cmpl-a06d218f31d549d7b128ee067691a83a: prompt: 'San Francisco is a', sampling params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=0.0, top_p=1.0, top_k=-1, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], ignore_eos=False, max_tokens=7, logprobs=None), use_echo=False, prompt token ids: [1, 3087, 8970, 338, 263].

I may just be impatient, but I am excited to try this out with llm eval harness!

@wanmok (Contributor, Author) commented Sep 7, 2023

Hello, thanks for this! Have you tested it using the OpenAI-compatible endpoint? When I try using curl, it gets the request but I never get a response back. The following is with echo off, but it hangs regardless of whether it is on or off.

INFO 09-07 06:06:49 async_llm_engine.py:226] Received request cmpl-a06d218f31d549d7b128ee067691a83a: prompt: 'San Francisco is a', sampling params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, temperature=0.0, top_p=1.0, top_k=-1, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], ignore_eos=False, max_tokens=7, logprobs=None), use_echo=False, prompt token ids: [1, 3087, 8970, 338, 263].

I may just be impatient, but I am excited to try this out with llm eval harness!

What is the LM you were using for the request? Also, how many GPUs?

@eugenepentland

Turns out I still had tensor_parallel=2 on; I'm not sure that issue is related to your fork. Setting it to 1 fixed it. I'm running Llama-7B-HF on an RTX 3090.

This is the new internal server error I'm getting when I send the following args to http://localhost:8000/v1/completions on the OpenAI endpoint:

JSON Args:

{
    "model": "llama",
    "prompt": "San Fransisco is a",
    "max_tokens": 0,
    "echo": true,
    "logprobs": 5
}

Error:

  File "/home/epentland/miniconda3/envs/axol/lib/python3.9/site-packages/fastapi/routing.py", line 190, in run_endpoint_function
    return await dependant.call(**values)
  File "/home/epentland/ai/inference/echo/vllm/vllm/entrypoints/openai/api_server.py", line 565, in create_completion
    logprobs = create_logprobs(
  File "/home/epentland/ai/inference/echo/vllm/vllm/entrypoints/openai/api_server.py", line 182, in create_logprobs
    for i, p in top_logprobs[i].items()
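
For reference, this request can be reproduced with a plain HTTP client. A minimal sketch, assuming the server is running locally on port 8000:

```python
import requests

# Hypothetical reproduction of the failing request against a locally
# running OpenAI-compatible vLLM server.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "llama",
        "prompt": "San Fransisco is a",
        "max_tokens": 0,
        "echo": True,
        "logprobs": 5,
    },
)
print(response.status_code)
print(response.json())
```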

@wanmok (Contributor, Author) commented Sep 7, 2023

Turns out I still had tensor_parallel=2 on; I'm not sure that issue is related to your fork. Setting it to 1 fixed it. I'm running Llama-7B-HF on an RTX 3090.

This is the new internal server error I'm getting when I send the following args to http://localhost:8000/v1/completions on the OpenAI endpoint:

JSON Args:

{
    "model": "llama",
    "prompt": "San Fransisco is a",
    "max_tokens": 0,
    "echo": true,
    "logprobs": 5
}

Error:

  File "/home/epentland/miniconda3/envs/axol/lib/python3.9/site-packages/fastapi/routing.py", line 190, in run_endpoint_function
    return await dependant.call(**values)
  File "/home/epentland/ai/inference/echo/vllm/vllm/entrypoints/openai/api_server.py", line 565, in create_completion
    logprobs = create_logprobs(
  File "/home/epentland/ai/inference/echo/vllm/vllm/entrypoints/openai/api_server.py", line 182, in create_logprobs
    for i, p in top_logprobs[i].items()

Could you pull the latest commit? I have tested it with about 4k requests (similar to yours).

@eugenepentland

I tested the latest commit on my own hardware and on an A6000 on RunPod, getting the same error on both. Can you send the exact command for the OpenAI API server and the curl request / Python code you used to make the request?

@wanmok (Contributor, Author) commented Sep 8, 2023

I tested the latest commit on my own hardware and on an A6000 on RunPod, getting the same error on both. Can you send the exact command for the OpenAI API server and the curl request / Python code you used to make the request?

Ah, I have fixed it in the latest commit. It was introduced later by the modifications to top_logprobs, where the first entry should be None. Could you try running it again with your dataset? This PR introduces multiple changes, so I suppose it would be better to test more rigorously... Thanks for catching that.
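
In other words, the first prompt token has no preceding context, so its top_logprobs entry is None and the response-building loop has to handle it explicitly. A simplified sketch of the idea (not the actual create_logprobs code; the helper name is made up):

```python
def merge_top_logprobs(tokens, top_logprobs):
    """Build OpenAI-style top_logprobs entries, tolerating a leading None."""
    merged = []
    for token, entry in zip(tokens, top_logprobs):
        if entry is None:
            # First prompt token: no log probabilities are available.
            merged.append(None)
        else:
            merged.append(dict(entry))
    return merged
```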

@eugenepentland

It's working on my end now! Thank you for the great work!

I am making it work with lm-eval-harness so we can have much faster model evaluation.

@wanmok (Contributor, Author) commented Sep 11, 2023

@zhuohan123 @WoosukKwon - hi guys, I wonder if we could have someone review the PR. If any changes are requested, I would be happy to work on them. Thanks!

@zhuohan123 self-requested a review September 11, 2023 23:29
@zhuohan123 (Member) left a comment

Thank you for your contribution! This is a feature that a lot of our community members have been asking for. The PR is relatively complex, and I think it may require several rounds of review. I have left some of my initial comments; please take a look. Additionally, is it possible for you to add a test for this functionality here? After this round, I will run this PR, check its correctness, and then perform another round of review. Thanks again!

Review comments were left on:
vllm/engine/llm_engine.py
vllm/entrypoints/openai/api_server.py
vllm/sequence.py
vllm/sampling_params.py
vllm/worker/worker.py
vllm/model_executor/layers/sampler.py
@zhuohan123 added the enhancement (New feature or request) label Sep 12, 2023
@wanmok (Contributor, Author) commented Sep 24, 2023

@zhuohan123 I wonder if there is any update on the review. Please let me know if you have found any blockers.

@wanmok (Contributor, Author) left a comment

I have addressed the comments from the previous review. Please let me know if there are any other concerns about merging the PR.

@lifengjin

Any update on this? This is very useful.

@WoosukKwon mentioned this pull request Oct 13, 2023
@zhuohan123 (Member)

@wanmok Thanks for all of your contributions so far. I have been looking into this over the last week and tried a simpler modification to the vLLM core to implement this function. Please find my changes in PR #1328. Is it possible for you to add the OpenAI endpoint support for echo on top of that PR? Let me know if there is any issue.

@wanmok (Contributor, Author) commented Oct 14, 2023

@wanmok Thanks for all of your contributions so far. I have been looking into this over the last week and tried a simpler modification to the vLLM core to implement this function. Please find my changes in PR #1328. Is it possible for you to add the OpenAI endpoint support for echo on top of that PR? Let me know if there is any issue.

Well... it might take some time to read through the PR and re-implement this. Are you thinking of a new PR after that one is merged?

@zhuohan123 (Member)

Well... it might take some time to read through the PR and re-implement this. Are you thinking of a new PR after that one is merged?

PR #1328 modified the vLLM core so that you can just add prompt_logprobs=N to SamplingParams, and the returned RequestOutput will include a prompt_logprobs field. After #1328 is merged into the main branch, the only remaining work is to let the OpenAI API server use this new feature to support echo. Can you submit a new PR for this after #1328 is merged? I believe the code change to openai/api_server.py should be very similar to what you have done in this PR.
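
For context, a rough sketch of how #1328 is used from the engine side (offline LLM API; the model name is only a placeholder):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model
params = SamplingParams(max_tokens=1, prompt_logprobs=5)

outputs = llm.generate(["San Francisco is a"], params)
# After #1328, each RequestOutput carries a prompt_logprobs field holding the
# log probabilities of the prompt tokens (the first entry is None).
print(outputs[0].prompt_logprobs)
```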

@wanmok (Contributor, Author) commented Oct 15, 2023 via email

@zhuohan123 (Member)

Sure thing! Actually in that PR I used some of your testing code. Will add it when I merge the PR.

@RanchiZhao

Cool! Is this available now? I mean the log probs of prompt tokens.

@wanmok (Contributor, Author) commented Oct 16, 2023

Cool! Is this available now? I mean the log probs of prompt tokens.

I suppose the PR is not ready yet. If you would like to get logprobs of prompt tokens, you could try this PR as a temporary solution.

@zhuohan123 (Member)

@wanmok #1328 has been merged into the main branch!

@wanmok (Contributor, Author) commented Oct 17, 2023

@wanmok #1328 has been merged into the main branch!

Cool! I'll draft a new PR when I get time, probably in a week or so (traveling to CA this week).

@lifengjin

Any updates on this, please?

@zhuohan123 (Member)

Closing this PR in favor of #1328 and #1509. Thank you for your contribution!

@zhuohan123 closed this Oct 30, 2023