Description
Motivation.
TL;DR: There is high CPU overhead associated with each decode batch due to input preparation and output processing. Multi-step decoding amortizes these overheads over n steps at a time:
- Transfer of the sampled tokens from GPU to CPU for de-tokenization and the response to the client
- Generation of output for the user - pythonization of tensors into Python objects
- CPU preparation and generation of the next step's input metadata
- vLLM scheduler invocation
The result is that the GPU is often idle, waiting for CPU operations (5-13ms of GPU bubble).
Multi-step decoding performs multiple decode passes before doing a GPU-CPU sync to invoke the vLLM scheduler and process the sampled tokens. Currently, the GPU->CPU memory transfer of sampled tokens is synchronous with each decode step, causing bubbles on the GPU. With multi-step, this memory transfer can happen in a separate CUDA stream and is essentially free, as the CPU runs ahead of the GPU.
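To make the transfer pattern concrete, below is a minimal PyTorch sketch (not vLLM's actual code) of copying sampled tokens to the CPU on a dedicated stream so the CPU can keep running ahead; the function and variable names are illustrative.

```python
import torch

# Sketch: copy sampled tokens GPU->CPU on a dedicated stream so the default
# (compute) stream and the CPU are never blocked; a CUDA event tells us later
# when the copy has finished.
copy_stream = torch.cuda.Stream()

def async_copy_sampled_tokens(sampled_token_ids: torch.Tensor):
    """Start a non-blocking D2H copy; returns (pinned_cpu_tensor, event)."""
    cpu_buf = torch.empty(sampled_token_ids.shape,
                          dtype=sampled_token_ids.dtype,
                          device="cpu", pin_memory=True)
    event = torch.cuda.Event()
    # Make the copy stream wait until the compute stream has produced the tokens.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        cpu_buf.copy_(sampled_token_ids, non_blocking=True)
        event.record()
    return cpu_buf, event

# The caller only blocks when the tokens are actually needed on the CPU:
#   event.synchronize(); token_list = cpu_buf.tolist()
```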
See below for the source of performance improvement.
- Both screenshots cover about 200ms. The top row shows the CUDA kernels and the bottom row contains the Python trace.
- Each highlighted red box represents about 4ms of GPU bubble in both images.
- For the baseline, this overhead is incurred on every decode step.
- With multi-step-8, the 4ms is incurred only once per 8 decode iterations.
Torch Profiles
Baseline 8B on 1xH100
Benchmarks
- ShareGPT using benchmark_serving.py
- Infinite request rate
- 1k requests
- With CUDA graphs
- No chunked prefill
- Input/output lengths are from the ShareGPT dataset
MS = multi-step
MS-8 = 8 multi-steps before calling the vLLM scheduler and process_output
| Single GPU | Baseline (Req/s) | MS-8 (Req/s) | MS-16 (Req/s) |
|---|---|---|---|
| A10G 8B Llama | 5.20 | 5.89 | - |
| H100 8B Llama | 20.66 | 40.06 | 43.31 |
| H100 30B Llama | 9.23 | 13.09 | 13.23 |
Proposed Change.
Extend `ExecuteModelRequest` (the input to the `Worker`s) and `RequestOutput`/`SamplerOutput` to include metadata for the multi-step state, and modify the existing `ModelRunner` to properly handle that state. `AsyncLLMEngine`/`LLMEngine` will need to be made aware of multi-step in order to call into the vLLM scheduler after n steps instead of on every decode. The existing PP scheduling will not be changed.
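As a rough illustration of that extension, the request metadata might look something like the following; the class and field names are assumptions for this sketch, not the final API.

```python
from dataclasses import dataclass

# Illustrative only -- the exact fields and names are part of the design work.
# The intent: ExecuteModelRequest tells a Worker how many decode steps to run
# before control returns to the engine, and the outputs record whether the
# sampled tokens are still GPU-resident.
@dataclass
class MultiStepRequestMetadata:
    num_steps: int = 1                # decode steps to run back-to-back
    is_first_multi_step: bool = True  # prepare fresh inputs vs. advance in place
    is_last_step: bool = False        # force GPU->CPU sync and pythonization

# SamplerOutput/RequestOutput would similarly gain a handle/flag indicating
# that sampled tokens are still on the GPU and when they become ready.
```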
High-level algorithm (a condensed code sketch follows the list):
- Scheduler
  - We have a fixed n steps and allocate additional (lookahead) blocks.
  - This applies only to decoding, not prefill; prefill runs the same way as today.
- At each worker
  - We prepare the initial inputs the same way as today.
  - Run the model.
  - The sampler does not synchronize CPU <-> GPU; it generates the next token on the GPU only.
  - At each iteration, we broadcast the sampled tokens to all workers.
  - Update the inputs for the next step. We use CUDA kernels for these updates because doing them in Torch is too slow.
  - Asynchronously transfer the sampled tokens to the CPU.
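The sketch below condenses the per-worker loop described above into Python. All helper names (`prepare_inputs`, `run_model`, `sample_on_gpu`, `broadcast_tokens`, `advance_step`, `async_copy_to_cpu`) are placeholders for existing or new vLLM internals, not real functions.

```python
# Condensed sketch of the per-worker multi-step loop (placeholder helpers).
def run_multi_step_decode(batch, num_steps, h):
    model_input = h.prepare_inputs(batch)        # prepared once, same as today
    pending_outputs = []
    for _ in range(num_steps):
        hidden_states = h.run_model(model_input)     # forward pass
        # Sampling stays on the GPU: no CPU<->GPU synchronization here.
        token_ids = h.sample_on_gpu(hidden_states)
        # TP/PP: every rank needs the tokens to update its own inputs.
        h.broadcast_tokens(token_ids)
        # In-place update of the next step's metadata via CUDA kernels
        # (doing this with plain Torch ops is too slow).
        model_input = h.advance_step(model_input, token_ids)
        # D2H copy runs on a separate stream; the CPU keeps running ahead.
        pending_outputs.append(h.async_copy_to_cpu(token_ids))
    return pending_outputs
```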
Details:
Multi-step state that needs to be tracked for each (micro)batch (see the dataclass sketch after this list):
- Current step that the batch is on / remaining lookahead slots available
- `sampled_token_ids` - to keep track of sampled tokens still on the GPU
- `sampler_output_ready_event` - CUDA event to make sure we only pythonize once GPU sampling has finished
- CUDA events for any forward passes that have not completed yet
- Any buffers that might be needed for async in-place update of attention metadata (depends on the backend)
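One possible shape for this bookkeeping, mirroring the bullets above (a sketch with assumed names, not the merged implementation):

```python
from dataclasses import dataclass, field
from typing import List, Optional
import torch

# Sketch only: one object per (micro)batch holding the multi-step state.
@dataclass
class MultiStepBatchState:
    current_step: int = 0                    # steps already executed
    remaining_lookahead_slots: int = 0       # slots left before rescheduling
    sampled_token_ids: List[torch.Tensor] = field(default_factory=list)  # still on GPU
    sampler_output_ready_event: Optional[torch.cuda.Event] = None  # gates pythonization
    forward_pass_events: List[torch.cuda.Event] = field(default_factory=list)  # unfinished forwards
    attn_metadata_buffers: Optional[dict] = None  # backend-specific async-update buffers
```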
Core changes to Engine (control flow sketched after this list):
- Add an attribute to the scheduler config, engine arguments, and CLI to enable the vLLM scheduler to return lookahead slots (previously only used for spec-decode)
- Skip vLLM scheduler invocation if we have not run out of lookahead slots for a batch of decodes
- Capture pythonized outputs as they become ready to return to the client.
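A simplified view of that engine-side control flow; attribute and method names here are illustrative assumptions, not the existing engine API.

```python
# Sketch of the engine step with multi-step enabled (illustrative names).
def engine_step(self):
    if self.lookahead_slots_remaining > 0:
        # Still inside a multi-step window: reuse the previous schedule
        # and skip the vLLM scheduler entirely.
        scheduled = self.cached_schedule
        self.lookahead_slots_remaining -= 1
    else:
        # Out of lookahead slots: process outputs and produce a fresh schedule.
        scheduled = self.scheduler.schedule()
        self.cached_schedule = scheduled
        self.lookahead_slots_remaining = self.num_scheduler_steps - 1
    outputs = self.model_executor.execute_model(scheduled)
    # Only hand back outputs whose GPU->CPU transfer has already completed.
    return [o for o in outputs if o.is_pythonized]
```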
Core changes to `ModelRunner` (event handling sketched after this list):
- For TP/PP: broadcast the sampled token to all other ranks so that each of them can call `advance_step`
- Synchronize with the previous forward passes using CUDA events to make sure the CPU does not clobber any GPU tensors currently in use when preparing the inputs for the next step
- Synchronize with the previous forward pass's sampler and start the GPU->CPU transfer in a separate CUDA stream
- Pythonize any ready GPU tensors if the CPU is running ahead
- Invoke the correct `advance_step` for in-place updating of the next step's input metadata
- Make sure to block on any remaining forward passes or GPU->CPU transfers if out of lookahead slots, so that the Engine can call into the vLLM scheduler
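For the event handling in particular, the "pythonize if ready, block only when forced" behaviour might look roughly like this; the torch event calls are real, but the surrounding structure and names are assumptions.

```python
# Sketch: convert finished GPU sampler outputs into Python objects.
# event.query() is non-blocking, so the CPU can keep running ahead; when the
# batch is out of lookahead slots we pass force=True and block on everything.
def maybe_pythonize(pending_outputs, force: bool = False):
    ready = []
    for cpu_tokens, event in pending_outputs:   # (pinned CPU tensor, CUDA event)
        if force:
            event.synchronize()   # block: the engine needs every output now
        elif not event.query():
            continue              # transfer still in flight; try again later
        ready.append(cpu_tokens.tolist())   # pythonization of the CPU tensor
    return ready
```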
Prototype:
The current prototype is based on speculative decode's `TP1DraftModelRunner` logic. There are numerous additions for PP/TP support. For the prototype we created a non-spec-decode `MultiStepModelRunner` under `workers/`. The goal is to generalize this into the existing `ModelRunner` (removing the need for a new file) before merging.
Reasoning: PP+multi-step
TL;DR: Since the current multi-step loop is inside ModelRunner/Worker, PP scheduling in the Executor would cause bubbles between each step and would not interleave the steps of Batch 1 (VE1) with Batch 2 (VE2).
Feedback Period.
No response
CC List.
@zhisbug @Yard1 @WoosukKwon @rkooo567 @zhuohan123 @simon-mo @comaniac @megha95 @richardliaw
Any Other Things.
Much thanks to @Yard1 for extensive help with design and implementation!
Syncing with @megha on ongoing work to make the output_processor async; she proposed moving the sampler out of the model runner.