
[V1][PP] Continue scheduling prefill chunks #13637


Closed
wants to merge 16 commits

Conversation

comaniac (Collaborator) commented Feb 21, 2025

This PR supports pipelining prefill chunks. The key changes are:

Schedule

  • Originally we updated num_computed_tokens in update_from_output, but that prevented us from scheduling the next chunk of a request before its previous chunk finished the pipeline, so we now advance num_computed_tokens right after scheduling (see the sketch after this list).
  • scheduled_req_ids is now Dict[str, int], where the value is the number of in-flight batches in which the request is currently scheduled. The value can be larger than one only when we schedule multiple prefill chunks of the same request.
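
A rough sketch of these two scheduler changes (illustrative only; MiniScheduler, Request, and max_chunk_size below are made-up stand-ins, not the actual vLLM scheduler API):

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class Request:
    req_id: str
    num_tokens: int              # total prompt tokens
    num_computed_tokens: int = 0

class MiniScheduler:
    def __init__(self, max_chunk_size: int = 512) -> None:
        self.max_chunk_size = max_chunk_size
        # req_id -> number of in-flight batches this request is scheduled in.
        # The value exceeds 1 only when multiple prefill chunks of the same
        # request are pipelined.
        self.scheduled_req_ids: dict[str, int] = defaultdict(int)

    def schedule(self, requests: list[Request]) -> dict[str, int]:
        scheduled: list[tuple[Request, int]] = []
        for req in requests:
            num_new_tokens = min(req.num_tokens - req.num_computed_tokens,
                                 self.max_chunk_size)
            if num_new_tokens == 0:
                continue
            scheduled.append((req, num_new_tokens))
            self.scheduled_req_ids[req.req_id] += 1

        # Snapshot the output *before* advancing any counters so the model
        # runner sees the token positions this batch was scheduled with.
        scheduler_output = {req.req_id: n for req, n in scheduled}

        # Advance num_computed_tokens right after scheduling (instead of in
        # update_from_output), so the next schedule() call can already pick
        # the following prefill chunk of the same request.
        for req, num_new_tokens in scheduled:
            req.num_computed_tokens += num_new_tokens
        return scheduler_output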

Cached Request Data

  • We may have more than one cached request data entry per request, because when a request is scheduled in two in-flight batches, a single cached entry cannot be reused by both (see the sketch after this list).
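
A rough sketch of the bookkeeping this implies, with hypothetical names (CachedRequestData and the helper functions below are simplified stand-ins): each scheduled batch keeps its own snapshot instead of reusing a single entry.

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class CachedRequestData:
    req_id: str
    num_computed_tokens: int   # value at the time this batch was scheduled

# One list of per-batch snapshots per request, instead of a single reusable
# entry that would be overwritten if the request is scheduled twice.
cached_reqs: dict[str, list[CachedRequestData]] = defaultdict(list)

def on_schedule(req_id: str, num_computed_tokens: int) -> CachedRequestData:
    data = CachedRequestData(req_id, num_computed_tokens)
    cached_reqs[req_id].append(data)
    return data

def on_batch_finished(req_id: str) -> CachedRequestData:
    # Batches of the same request finish in scheduling order, so pop the
    # oldest snapshot.
    return cached_reqs[req_id].pop(0)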

Update from output

  • This function can no longer refer to request.num_computed_tokens, because that value may already have been advanced by batches scheduled after this one. Instead, it should refer to scheduler_output.scheduled_xxx_reqs[req_id].num_computed_tokens, which records the value at the time this batch was scheduled (see the sketch after this list).
  • For speculative decoding, we now simply subtract the number of rejected tokens.
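
A rough sketch of the update path under the same assumptions (scheduled_reqs below is a hypothetical stand-in for the scheduled_xxx_reqs mapping, and the dataclasses are simplified):

from dataclasses import dataclass

@dataclass
class PerBatchReqData:
    num_computed_tokens: int   # recorded when this batch was scheduled

@dataclass
class SchedulerOutputSketch:
    scheduled_reqs: dict[str, PerBatchReqData]  # req_id -> per-batch snapshot

@dataclass
class RequestSketch:
    req_id: str
    num_computed_tokens: int   # may already include later chunks

def update_from_output(scheduler_output: SchedulerOutputSketch,
                       request: RequestSketch,
                       num_rejected_tokens: int = 0) -> int:
    # Do NOT read request.num_computed_tokens here: it may already have been
    # advanced for chunks scheduled after this batch. Use the snapshot taken
    # when this batch was scheduled instead.
    computed_at_schedule = (
        scheduler_output.scheduled_reqs[request.req_id].num_computed_tokens)

    # Speculative decoding: simply subtract the rejected draft tokens from
    # the request's running counter.
    if num_rejected_tokens:
        request.num_computed_tokens -= num_rejected_tokens

    return computed_at_schedule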

cc @WoosukKwon @ruisearch42


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 21, 2025
@comaniac comaniac linked an issue Feb 21, 2025 that may be closed by this pull request
ruisearch42 (Collaborator) left a comment

Thanks! The main part looks good; I have not reviewed the tests yet.

ruisearch42 (Collaborator) left a comment

LG

WoosukKwon (Collaborator) left a comment

Thanks for the PR! I got the overall idea of the PR but didn't get the full details. Please check out my comments.

comaniac (Collaborator, Author) commented

@WoosukKwon thanks for the review. The challenging points in this PR that complicate the code are summarized as follows:

  1. We advance num_computed_tokens at the end of .schedule(), after the SchedulerOutput is formed, because we need the values in SchedulerOutput to remain unchanged. This is required for the model runner and input preparation to get the correct input IDs and positions.
  2. However, since the request object is shared by .schedule() and .update_from_output(), we will see the already-advanced num_computed_tokens in .update_from_output(). This is incorrect if a request is scheduled in two in-flight batches, because when updating the output of the first batch we will see the num_computed_tokens advanced by the second.
  3. As a result, we instead have to look at scheduler_output.req.num_computed_tokens for the correct value. However, scheduler_output.req is currently the cached request data, which is shared by all scheduler outputs. So if a request is scheduled twice and we only have one cached request data entry, these values get clobbered (a small replay below makes this concrete).
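
To make points 2 and 3 concrete, a small replay with made-up chunk sizes (illustrative only, not the actual code):

prompt_len, chunk = 1000, 512
num_computed_tokens = 0  # the counter on the shared request object

# schedule() batch A (first prefill chunk): snapshot, then advance.
batch_a_snapshot = num_computed_tokens                                 # 0
num_computed_tokens += min(chunk, prompt_len - num_computed_tokens)    # 512

# schedule() batch B (second chunk) before batch A's output is back.
batch_b_snapshot = num_computed_tokens                                 # 512
num_computed_tokens += min(chunk, prompt_len - num_computed_tokens)    # 1000

# update_from_output() for batch A now sees 1000 on the request object, so it
# must use batch A's snapshot -- and since batch B needs its own snapshot too,
# a single shared cached request data entry is not enough.
assert (batch_a_snapshot, batch_b_snapshot, num_computed_tokens) == (0, 512, 1000)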


mergify bot commented Feb 27, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


mergify bot commented Mar 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 11, 2025
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
comaniac (Collaborator, Author) commented Mar 11, 2025

Benchmark with PP=2 on 2xL4 GPUs:

Example Serving Command

VLLM_USE_V1=1 vllm serve unsloth/Llama-3.1-8B-Instruct \
--no-enable-prefix-caching \
--disable-log-requests \
--distributed-executor-backend="ray" \
--enable-chunked-prefill \
--max-num-batched-tokens=2048 \
--max-model-len=16384 \
--pipeline-parallel-size=2

Benchmark Command

python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model unsloth/Llama-3.1-8B-Instruct \
    --dataset-name burstgpt \
    --dataset-path BurstGPT_without_fails_2.csv \
    --request-rate 2 \
    --num-prompts 150
  • V0, MP, main
============ Serving Benchmark Result ============
Successful requests:                     150
Benchmark duration (s):                  185.01
Total input tokens:                      106025
Total generated tokens:                  42550
Request throughput (req/s):              0.81
Output token throughput (tok/s):         229.99
Total Token throughput (tok/s):          803.06
---------------Time to First Token----------------
Mean TTFT (ms):                          366.26
Median TTFT (ms):                        245.31
P99 TTFT (ms):                           1650.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          121.15
Median TPOT (ms):                        121.40
P99 TPOT (ms):                           166.25
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.67
Median ITL (ms):                         109.92
P99 ITL (ms):                            494.85
==================================================
  • V1, Ray, main
============ Serving Benchmark Result ============
Successful requests:                     150
Benchmark duration (s):                  175.30
Total input tokens:                      106025
Total generated tokens:                  41936
Request throughput (req/s):              0.86
Output token throughput (tok/s):         239.23
Total Token throughput (tok/s):          844.05
---------------Time to First Token----------------
Mean TTFT (ms):                          359.48
Median TTFT (ms):                        217.95
P99 TTFT (ms):                           1705.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          110.96
Median TPOT (ms):                        110.21
P99 TPOT (ms):                           155.08
---------------Inter-token Latency----------------
Mean ITL (ms):                           107.03
Median ITL (ms):                         95.87
P99 ITL (ms):                            507.94
==================================================
  • V1, Ray, this PR
============ Serving Benchmark Result ============
Successful requests:                     150
Benchmark duration (s):                  173.27
Total input tokens:                      106025
Total generated tokens:                  41920
Request throughput (req/s):              0.87
Output token throughput (tok/s):         241.94
Total Token throughput (tok/s):          853.85
---------------Time to First Token----------------
Mean TTFT (ms):                          308.59
Median TTFT (ms):                        197.27
P99 TTFT (ms):                           1234.44
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          107.31
Median TPOT (ms):                        107.30
P99 TPOT (ms):                           150.19
---------------Inter-token Latency----------------
Mean ITL (ms):                           103.79
Median ITL (ms):                         94.51
P99 ITL (ms):                            492.35
==================================================

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
comaniac (Collaborator, Author) commented

@WoosukKwon I refactored the PR a bit to reduce possible confusion and added the benchmark results above. PTAL

WoosukKwon (Collaborator) left a comment

Thanks for the PR! Left some comments.


mergify bot commented Mar 14, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Mar 14, 2025
@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 14, 2025
ruisearch42 (Collaborator) left a comment

Thanks for the optimization. A couple of minor issues/questions.

WoosukKwon (Collaborator) left a comment

LGTM!

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
@comaniac comaniac enabled auto-merge (squash) March 14, 2025 21:15
WoosukKwon (Collaborator) left a comment

@comaniac I see some performance regression in the ShareGPT benchmark.

@WoosukKwon WoosukKwon disabled auto-merge March 14, 2025 21:34
WoosukKwon (Collaborator) commented

@comaniac It actually hangs at the end of the benchmark... 😭 Can you please check again?

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
comaniac (Collaborator, Author) commented Mar 14, 2025

> @comaniac It actually hangs at the end of the benchmark... 😭 Can you please check again?

Fixed a bug where num_computed_tokens was not handled correctly when prefix caching is enabled. Specifically, orig_num_computed_tokens should include the prefix-cached tokens (see the sketch below).

Please try again.
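
For illustration only (only the name orig_num_computed_tokens comes from the description above; the numbers and surrounding context are made up):

num_tokens_from_earlier_chunks = 512   # computed by previously scheduled chunks
num_prefix_cached_tokens = 256         # hit in the prefix cache

# Buggy: prefix-cached tokens were not counted as already computed.
orig_num_computed_tokens = num_tokens_from_earlier_chunks

# Fixed: prefix-cached tokens are added to the original computed count.
orig_num_computed_tokens = (num_tokens_from_earlier_chunks
                            + num_prefix_cached_tokens)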

WoosukKwon (Collaborator) commented

As we discussed offline, let's hold off on this PR until we make the v0.8.0 release.

comaniac (Collaborator, Author) commented

@WoosukKwon should this be good to go now that the 0.8.0 branch has been cut?

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>

mergify bot commented Mar 30, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 30, 2025
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
@mergify mergify bot removed the needs-rebase label Apr 7, 2025
WoosukKwon (Collaborator) left a comment

@comaniac thanks for the PR! Looks good to me in general, but I still have a few questions on scheduled_req_ids. Please check out my comments.

# With PP, when the input prompt is divided into chunks, we can
# schedule a new chunk even before the previous chunk has completed
# the full pipeline stages. This helps reduce TTFT.
self.scheduled_req_ids: dict[str, int] = defaultdict(int)
WoosukKwon (Collaborator) left a comment

Dumb question: What is scheduled_req_ids used for? IIUC, this PR eliminates its usage.

comaniac (Collaborator, Author) replied

This may be a good catch. Since we now simply check whether num_new_tokens == 0 to determine whether we can continue scheduling a request, we may not need this anymore...

WoosukKwon (Collaborator) commented Apr 9, 2025

@comaniac If so, can you please remove it?

comaniac (Collaborator, Author) replied

Updated.

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Cody Yu <cody@openai.com>

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>

mergify bot commented May 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @comaniac.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 22, 2025
@comaniac comaniac closed this May 24, 2025
Labels
needs-rebase ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[V1][PP] Pipeline chunked prefill
3 participants