Skip to content

[Core] Add dynamic chunk size calculation #10061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

prashantgupta24
Copy link
Contributor

@prashantgupta24 prashantgupta24 commented Nov 6, 2024

When doing chunked prefill, calculate the token budget for a single chunk. This dynamically scales the chunk size down as the number of sequences that require prefilling increases. This ensures that a single sequence with a very large prompt to prefill doesn't take the entire remaining token budget, allowing other sequences to prefill and decode concurrently.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

joerunde and others added 8 commits November 4, 2024 17:12
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Copy link

github-actions bot commented Nov 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@prashantgupta24 prashantgupta24 changed the title [Core] Add Min chunk size [Core] Add dynamic chunk size calculation Nov 6, 2024
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Copy link

mergify bot commented Nov 6, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @prashantgupta24 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2024
Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. While I understand that this approach attempts to balance the TTFT of all requests, I have the following concerns:

  1. This results in many requests in the partial prefill stage. For example, you may have 10 requests with 1k prompt, and you will schedule all of them to process a chunk in each step. However, we now allocate kv blocks for an entire prompt when scheduling a request, meaning that we will allocate 10k/16 blocks in the first step. Even assuming we have more than 10k/16 blocks available, this may result in lots of preemptions once all 10 requests get to decoding stage.
  2. The implementation seems to have some overheads. Please be aware that scheduler is in the critical path of TTFT and ITL, so even 1 ms is a huge overhead.
  3. I don't think this approach benefits throughputs, so for the high throughput scenario, people would not use this feature (min_chunk_size=None). Then there's only overheads.

In short, I'm worry about the overall performance for the scenarios that cannot benefit from this approach. A solid benchmark could be a good reference to help proceed.

Comment on lines +1610 to +1618
# calculate a chunk size that shares it evenly across sequences that
# need to prefill
chunk_size = int(remaining_token_budget / prefilling_seqs)
# Ensure the chunk size is at least the minimum configured by the
# user, to limit the number of requests doing prefill
chunk_size = max(chunk_size, self.scheduler_config.min_chunk_size)
# And cap that at our actual budget so we don't spend tokens we
# don't have.
chunk_size = min(remaining_token_budget, chunk_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again I feel in high QPS the overhead of this logic would be large. Specifically, when there are many prefill requests, you mostly would just allocate min_chunk_size, making this calculation not effective.

prashantgupta24 and others added 3 commits November 6, 2024 10:44
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
@njhill
Copy link
Member

njhill commented Nov 7, 2024

I haven't looked at the impl/logic closely but instead of treating all prefills equally perhaps we can admit new ones based on the input size being smaller than some fraction of the remaining prefill tokens of existing prefill reqs in the batch. This would address the core problem while I think avoiding most of the issues @comaniac raised. And in theory could even have some benefit to throughput since more reqs would get to decode stage faster?

@joerunde
Copy link
Collaborator

joerunde commented Nov 7, 2024

Thanks for the PR. While I understand that this approach attempts to balance the TTFT of all requests, I have the following concerns:

  1. This results in many requests in the partial prefill stage. For example, you may have 10 requests with 1k prompt, and you will schedule all of them to process a chunk in each step. However, we now allocate kv blocks for an entire prompt when scheduling a request, meaning that we will allocate 10k/16 blocks in the first step. Even assuming we have more than 10k/16 blocks available, this may result in lots of preemptions once all 10 requests get to decoding stage.
  2. The implementation seems to have some overheads. Please be aware that scheduler is in the critical path of TTFT and ITL, so even 1 ms is a huge overhead.
  3. I don't think this approach benefits throughputs, so for the high throughput scenario, people would not use this feature (min_chunk_size=None). Then there's only overheads.

In short, I'm worry about the overall performance for the scenarios that cannot benefit from this approach. A solid benchmark could be a good reference to help proceed.

Yeah, all good points!

1: I wouldn't expect min_chunk_size to be configured so small that all 10 requests would start to prefill, e.g.that would mean setting --max-num-batched-tokens=512 --min-chunk-size=50 and with chunks that small, I think prefill would just take a lot longer in general. Setting those to 512/256 to allow two requests to prefill at a time would be more reasonable. I agree we don't want to be reckless with blowing up the KV cache, but we're trying out these changes because a customer is frequently hitting an issue where a single 130k context prompt is blocking the queue during prefill for a solid minute, while only using 10% of the KV cache.

2/3: Yeah we definitely need to not do any extra chunk size calculations if min_chunk_size is unset, and we can clean up / cache some calculations to decrease overhead too.

@joerunde
Copy link
Collaborator

joerunde commented Nov 7, 2024

Re: benchmarking, we created a dataset with super high variance in the prompt lengths- 75% prompts with tens of tokens and 25% prompts with thousands of tokens. This change is currently faster on TTFT but slower on ITL, so we'll work on reducing that overhead and see if we can beat both.

joerunde and others added 2 commits November 7, 2024 10:28
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
@comaniac
Copy link
Collaborator

comaniac commented Nov 8, 2024

FYI: @rickyyx found an issue with multiple partial prefills on the fly. Particularly there's an assumption in the sampler that we will only have one partial prefill request in a batch. I'm a bit worry this assumption may be everywhere so we have to be careful.

@rickyyx
Copy link
Contributor

rickyyx commented Nov 8, 2024

FYI: @rickyyx found an issue with multiple partial prefills on the fly. Particularly there's an assumption in the sampler that we will only have one partial prefill request in a batch. I'm a bit worry this assumption may be everywhere so we have to be careful.

Yeah, specifically for the sampler one I think it's assumed we would only have one partial prefill(that doesn't require prompt logprobes):

for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
maybe_deferred_sample_results,
prompt_logprobs, sample_logprobs):

All inputs to the zip were assumed to have equal len, but when there are multiple partial prefills, this assumption might not hold true.

Copy link

mergify bot commented Dec 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @prashantgupta24.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2024
@njhill
Copy link
Member

njhill commented Feb 5, 2025

Closing because this was superseded by #10235.

@njhill njhill closed this Feb 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants