diff --git a/docs/source/index.rst b/docs/source/index.rst index 5cc28a2d70139..4022c590843e6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -87,6 +87,7 @@ Documentation models/adding_model models/engine_args models/lora + models/performance .. toctree:: :maxdepth: 1 diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst new file mode 100644 index 0000000000000..067757699f32a --- /dev/null +++ b/docs/source/models/performance.rst @@ -0,0 +1,38 @@ +.. _performance: + +Performance and Tuning +====================== + +Chunked Prefill +--------------- +vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. + +You can enable the feature by specifying + +.. code-block:: python + + llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) + # Set max_num_batched_tokens to tune performance. + # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. + # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) + +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. + +Once chunked prefill is enabled, the policy is changed to + +- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. +- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. + +This policy has two benefits. + +- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. + +You can tune the performance by changing `max_num_batched_tokens`. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. +Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. +Higher batch size achieves better TTFT as you can put more prefill to the batch. +If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). diff --git a/vllm/config.py b/vllm/config.py index fe54c54bed48e..6c65bbe247f84 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -607,8 +607,9 @@ def __init__( self.max_num_batched_tokens = max_num_batched_tokens else: if enable_chunked_prefill: - # For chunked prefill, choose the well-tuned batch size. - self.max_num_batched_tokens = 768 + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + self.max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput.