forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Doc] Chunked Prefill Documentation (vllm-project#4580)
- Loading branch information
Showing
3 changed files
with
42 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
.. _performance: | ||
|
||
Performance and Tuning | ||
====================== | ||
|
||
Chunked Prefill | ||
--------------- | ||
vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. | ||
|
||
You can enable the feature by specifying | ||
|
||
.. code-block:: python | ||
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) | ||
# Set max_num_batched_tokens to tune performance. | ||
# NOTE: 512 is the default max_num_batched_tokens for chunked prefill. | ||
# llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) | ||
By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. | ||
|
||
Once chunked prefill is enabled, the policy is changed to | ||
|
||
- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. | ||
- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. | ||
|
||
This policy has two benefits. | ||
|
||
- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. | ||
- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. | ||
|
||
You can tune the performance by changing `max_num_batched_tokens`. | ||
By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. | ||
Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. | ||
Higher batch size achieves better TTFT as you can put more prefill to the batch. | ||
If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). | ||
Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. | ||
|
||
See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters