Description
Motivation #
There is significant interest in vLLM supporting encoder/decoder models. Issues #187 and #180, for example, request encoder/decoder model support. As a result, encoder/decoder support was recently introduced to vLLM via the following three PRs:
- [Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) #4837
- [Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) #4888
- [Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) #4942
These three PRs make encoder/decoder model inference possible; however, more work remains with respect to (1) feature parity between vLLM's decoder-only and encoder/decoder request processing pipelines, and (2) the number of encoder/decoder models which are supported.
The ask for the vLLM community is to contribute PRs which help bring vLLM encoder/decoder functionality to a similar level of maturity as that of vLLM's decoder-only functionality.
Proposed changes #
The support matrix below summarizes which encoder/decoder models have already been added & which features are currently compatible with the vLLM encoder/decoder pipeline, versus which features & models will require additional PRs to implement in the long-term:
| Model/feature | Model is already available / feature is already compatible with encoder/decoder? | Having this model / making this feature compatible is a long-term goal? |
|---|---|---|
| Encoder/decoder infrastructure | Yes | Yes |
| BART | Yes | Yes |
| Whisper | No | Yes |
| T5 | No | Yes |
| Other enc/dec models | No | Yes |
| Quantization | Untested | Yes |
| Multimodality | No | Yes |
| Attention backends other than XFormers (esp. flash-attn, flashinfer) | No | Yes |
| Custom attention bias support | No | Yes |
| CUDAGraph | No (Issue #7447) | Yes |
| Pipeline parallelism | No | Yes |
| Speculative decoding | No | Low-priority but nice-to-have; difficult |
| Automatic prefix caching | No | Low-priority; difficult |
| Sliding window | No | No |
| Chunked prefill | No | No |
| LoRA | No | No |
This RFC gives an overview of those features & models which are not compatible with encoder/decoder currently, but which should be made compatible eventually (i.e. No in the second column and Yes in the third column of the support matrix.)
Note that there are features (automatic prefix caching/sliding window/chunked prefill/LoRA) which are not long-term compatibility goals.
Background #
Before continuing, it will be helpful to review the details of the new vLLM encoder/decoder infrastructure.
It will also be helpful to review this how-to guide for adding new encoder/decoder models & improving encoder/decoder feature compatibility.
Initial goal #
Members of the vLLM contributor community identify models/features in the support matrix above for which they will write PRs.
Detailed long-term goals #
Add new models to vLLM #
Please review the how-to guide for adding new models to vLLM
See tests/models/test_bart.py for an example of an encoder/decoder model unit test. See tests/distributed/test_basic_distributed_correctness_enc_dec.py for an example of an encoder/decoder model test with TP > 1.
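For orientation, here is a minimal sketch of exercising a vLLM encoder/decoder model end-to-end with the LLM entrypoint, in the spirit of the tests above (the model name, prompt, and sampling settings are just examples, not a prescribed test):

```python
from vllm import LLM, SamplingParams

# Greedy decoding with an encoder/decoder model; tests/models/test_bart.py
# additionally compares the output against HuggingFace Transformers.
llm = LLM(model="facebook/bart-large-cnn", dtype="float16")
params = SamplingParams(temperature=0.0, max_tokens=32)

# With a singleton text prompt, the text is tokenized and passed to the
# encoder; the decoder receives a default "start" prompt.
outputs = llm.generate(["The rain in Spain falls mainly on the"], params)
print(outputs[0].outputs[0].text)
```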
Add Whisper model #
Steps to add support for Whisper, a multimodal encoder/decoder speech recognition model:
- Extend existing vLLM multimodality support to encoder/decoder models
- Extend existing vLLM prompt processing pipeline to support audio
- Port HuggingFace Whisper model to vLLM; an existing open PR for this workstream is Whisper support #5964
- Modify each Whisper layer, where appropriate, to support TP > 1
- Add a Whisper test under tests/models/

Proposal: consider whether or not it makes sense to implement encoder/decoder multimodality, audio support, and Whisper in the same PR; that way, the Whisper model may be used to facilitate an end-to-end test of audio multimodality.
Add T5 model #
Note: T5 depends on custom attention bias being supported by at least one of the attention backends which also supports encoder attention & cross-attention; at time of writing no vLLM attention backend fulfills this requirement. The vLLM XFormers attention backend is the only backend which supports encoder/decoder models but neither it nor any other vLLM attention backend supports custom attention bias. (Custom attention bias is required in order to support T5 relative positional encoding.)
Steps to add support for the T5 model:
- Port HuggingFace T5 model to vLLM
- This includes porting over the method which computes the custom attention bias matrix for T5 relative position encoding
- Modify each T5 layer, where appropriate, to support TP > 1
- The custom attention bias computation must also support TP > 1
- Add a T5 test to tests/models/
Note: T5 was added to an older version of vLLM in #3117 , which could be a helpful starting-point
Add other encoder/decoder models #
- Review open vLLM issues on GitHub and identify other encoder/decoder models which are requested by users
Quantization #
The goal of this workstream is to make sure that quantization + encoder/decoder models is fully-tested, and to fill in any gaps (should they exist) in vLLM's support for quantized encoder/decoder models.
Steps to ensure that vLLM supports encoder/decoder models in combination with all existing vLLM quantization methods:
- Identify the list of quantization methods which vLLM currently supports with decoder-only models.
- Add unit tests for encoder/decoder models with all of these quantization methods.
- Determine which quantization methods are currently incompatible with vLLM encoder/decoder infrastructure.
- Scope out the effort involved in making these quantization methods compatible & submit a PR making the change.
vLLM encoder/decoder infrastructure should be compatible with most of the existing vLLM quantization methods, because the specialized quantization kernels are only employed for GEMM operations involving the learned weight matrices; the encoder/decoder infrastructure primarily changes how the Attention(q, k, v, kv_cache) layer behaves and does not impact the learned weight matrices at all.
It is less clear whether vLLM encoder/decoder infrastructure would be incompatible with FP8, since it does appear that a specialized quantized KV cache kernel is employed by the Attention(q, k, v, kv_cache) layer when FP8 quantization is employed.
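To make the unit-testing step concrete, here is a hedged sketch of what such a test might look like; the model name and quantization method are placeholders, and whether a given combination actually runs is precisely what this workstream needs to determine:

```python
import pytest
from vllm import LLM, SamplingParams


# Hypothetical smoke test: load an encoder/decoder model with a given
# quantization method and check that greedy decoding produces output.
@pytest.mark.parametrize("quantization", ["fp8"])  # extend as methods are verified
def test_enc_dec_quantization(quantization):
    llm = LLM(model="facebook/bart-large-cnn", quantization=quantization)
    params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(["The rain in Spain falls mainly on the"], params)
    assert outputs and outputs[0].outputs[0].text
```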
Support encoder/decoder multimodality #
Technically, vLLM already supports multimodality for models which have an "encoder" and a "decoder", e.g. Llava. However, Llava's decoder does not utilize cross-attention, and the model is basically compatible with vLLM's pre-existing decoder-only infrastructure.
But critically, for encoder/decoder models with cross-attention such as Whisper, vLLM does not currently support multimodality of any sort. The processing pipeline does not extract or utilize multimodal data from the input prompt, and the EncoderDecoderModelRunner has an assert which fails if the multimodal config is not None. Addressing this is what is meant by "supporting encoder/decoder multimodality".
Steps to extend existing vLLM multimodality support to encoder/decoder models:
- Review existing vLLM multimodality support in the decoder-only pipeline
- Scope out a plan for adding encoder/decoder multimodality support.
- Propose & implement one or more multimodal prompt formats for encoder/decoder models
- Integrate multimodality support into encoder/decoder processing pipeline
- Remove the assertion which fails when multimodality is enabled for an encoder/decoder model (see assert_enc_dec_mr_supported_scenario() in vllm/worker/utils.py)
- Add one or more unit tests with multimodal data
There are a number of multimodal encoder/decoder models which will benefit from this feature. One possibility is to add multimodality support & a multimodal model such as Whisper in the same PR, so that Whisper may be used to facilitate an end-to-end test with multimodality.
Another possibility is to implement multimodality support in its own PR.
Considerations for designing multimodal encoder/decoder prompt formats #
One approach to designing the vLLM multimodal encoder/decoder prompt formats is to consider what we want the user experience to be for high-priority multimodal encoder/decoder models such as Whisper.
Initial proposal for multimodal encoder/decoder prompt formats
It may be helpful to review:
- The non-multimodal encoder/decoder prompt formats which are currently supported by vLLM: singleton prompts (raw text prompt, TextPrompt, TokensPrompt) as well as ExplicitEncoderDecoderPrompt prompts
- The multimodal decoder-only prompt formats which are currently supported by vLLM; search for multi_modal_data here and also review the vLLM documentation on multimodality
Generally speaking, in encoder/decoder models based on cross-attention, the non-text input modality is passed to the encoder as input. Conversely, any text prompt is typically passed to the decoder as an input prompt.
The following two encoder/decoder multimodal prompt formats are tentatively proposed:

- Singleton TextPrompt with multi_modal_data field
  - vLLM will extract the multi_modal_data and pass it to the encoder module
  - vLLM will extract the prompt text, tokenize it, and pass the token list to the decoder (note that this is the opposite of vLLM behavior for non-multimodal prompts, where the prompt text would be passed to the encoder.)

  For example, passing the TextPrompt below to vLLM BART

      TextPrompt(
          'prompt': "The rain in spain falls mainly on the",
          'multi_modal_data': <multi modal data structure>
      )

  results in

      Encoder input: <multi modal data structure>
      Decoder prompt: "The rain in spain falls mainly on the"

- Singleton TokensPrompt with multi_modal_data field
  - vLLM will extract the multi_modal_data and pass it to the encoder module
  - vLLM will extract the token list and pass it unmodified to the decoder (note that this is the opposite of vLLM behavior for non-multimodal prompts, where the prompt tokens would be passed to the encoder.)

  For example, passing the TokensPrompt below to vLLM BART

      TokensPrompt(
          'prompt_tokens': [2,0,171,5,2],
          'multi_modal_data': <multi modal data structure>
      )

  results in

      Encoder prompt: <multi modal data structure>
      Decoder prompt: [2,0,171,5,2]
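To illustrate the proposed user experience, here is a hypothetical sketch of the first format applied to an audio model; Whisper support and the "audio" multi_modal_data key do not exist in vLLM yet, so every name below is illustrative only:

```python
import numpy as np

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt

# Placeholder audio input: one second of silence at 16 kHz.
waveform = np.zeros(16000, dtype=np.float32)

# Hypothetical: Whisper is not yet supported by vLLM (see the Whisper section above).
llm = LLM(model="openai/whisper-large-v3")

prompt = TextPrompt(
    prompt="<|startoftranscript|>",        # prompt text -> tokenized -> decoder
    multi_modal_data={"audio": waveform},  # multimodal data -> encoder
)
outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=64))
```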
It may also be worth considering whether or how to support:
- ExplicitEncoderDecoderPrompts with multimodality
- An input prompt format which encapsulates only multimodal encoder inputs, with no associated decoder text/tokens prompt (this would result in the decoder being passed a "default" or empty prompt.)
Add support for encoder attention and cross-attention to additional backends #
At time of writing, XFormers is the only vLLM attention backend which supports encoder attention & cross-attention.
The goal of this workstream would be to extend encoder attention & cross-attention support to additional backends, the highest-priority being flash-attention and flashinfer.
Reviewing encoder attention and cross-attention support in the XFormers backend would be a good starting-point for extending support to other backends.
For context on the requirements for a backend to support encoder and cross-attention, it may help to review the encoder/decoder architecture, the way that attention masks are currently constructed in the XFormers backend, and the recommended architecture for vLLM encoder/decoder models.
A summary of the key changes required for an attention backend to support encoder attention and cross-attention:
- The backend's AttentionMetadata subclass must support fields for encoder sequence lengths, encoder sequence token count, cross-attention block tables, and cross-attention slot mapping. XFormers examples:
- The forward() method of the backend implementation must accept an attn_type argument of type AttentionType, which allows choosing between encoder attention, decoder attention, or encoder/decoder cross-attention. XFormers example
- The backend implementation must recognize which option has been chosen for attn_type, and adjust accordingly in terms of (1) how it utilizes attn_metadata when invoking the attention kernels (review XFormers forward() for context), and (2) the choice of causal or non-causal attention, as well as the choice of attention mask shape (XFormers example).
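To summarize these requirements, here is a schematic sketch (not vLLM code) of the metadata fields and forward() signature described above; the names are modeled loosely on the XFormers backend and should be treated as illustrative:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional

import torch


class AttentionType(Enum):
    DECODER = auto()          # decoder self-attention (causal)
    ENCODER = auto()          # encoder self-attention (non-causal)
    ENCODER_DECODER = auto()  # cross-attention (non-causal, uses the cross-attn KV cache)


@dataclass
class EncDecAttentionMetadata:
    # Decoder-side fields present in every backend are omitted for brevity.
    # Encoder/cross-attention additions:
    encoder_seq_lens: Optional[List[int]] = None        # per-sequence encoder lengths
    num_encoder_tokens: Optional[int] = None             # total encoder tokens in the batch
    cross_block_tables: Optional[torch.Tensor] = None    # block tables for the cross-attn KV cache
    cross_slot_mapping: Optional[torch.Tensor] = None    # slot mapping for the cross-attn KV cache


def forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
            kv_cache: torch.Tensor, attn_metadata: EncDecAttentionMetadata,
            attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor:
    # A real backend must branch on attn_type: select decoder vs. encoder vs.
    # cross-attention metadata, and choose causal vs. non-causal masking.
    raise NotImplementedError("sketch only")
```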
Initial goals
- Identify the changes required to add encoder attention & cross-attention support to flash-attention and flashinfer
- PR the required changes
- Remove/modify any asserts which fail if the vLLM attention backend is not XFormers
  - Currently, the __init__() method of EncoderDecoderModelRunner invokes EncoderDecoderModelRunner._maybe_force_supported_attention_backend(), defined here, which (1) attempts to force encoder/decoder models to use the XFormers attention backend, and (2) raises an exception if the user has overridden the attention backend to be anything other than XFormers.
Long-term goals
- All vLLM attention backends support encoder attention and cross-attention
Support custom attention bias #
Note: T5 takes a dependency on custom attention bias. Custom attention bias is likely complex enough to merit its own PR.
Note: custom bias support was added to PagedAttention
in an older version of vLLM as part of #3117 ; given changes in vLLM since then, additional work would be required to integrate this implementation.
Custom attention bias and relative positional encoding
Attention bias refers to adding a bias matrix $A$ to the (scaled) $Q K^T$ scores before the $softmax$ is applied, i.e. computing $softmax(\frac{Q K^T}{\sqrt{d}} + A) V$ rather than $softmax(\frac{Q K^T}{\sqrt{d}}) V$.
Here, custom attention bias is understood to mean that the vLLM attention backend allows an arbitrary (possibly model-defined) bias matrix $A$ to be passed in and applied during the attention computation.
There are broadly two possible approaches to custom attention bias, which do not necessarily have to be mutually-exclusive:
- $A$ is a fully-materialized attention bias matrix passed to the attention backend
- $A$ is computed on-the-fly by the attention kernel, using an element-wise formula for the attention bias which is fused with the $Q K^T$ and $softmax$ computations
T5 employs custom attention bias in order to implement relative positional encoding, wherein pairwise positional relationships between tokens are represented by the bias matrix. The HuggingFace Transformers T5 implementation provides an example of how the relative positional encoding matrix is computed.
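For reference, here is a condensed sketch of that computation, following the logic of HuggingFace's T5Attention (simplified, bidirectional case only; names and defaults are illustrative):

```python
import math

import torch


def relative_position_bucket(relative_position: torch.Tensor,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    """Map signed token distances to bucket indices (bidirectional case)."""
    num_buckets //= 2
    buckets = (relative_position > 0).long() * num_buckets
    rel = relative_position.abs()
    max_exact = num_buckets // 2
    is_small = rel < max_exact
    # Larger distances are mapped logarithmically into the remaining buckets.
    rel_large = max_exact + (
        torch.log(rel.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    rel_large = torch.clamp(rel_large, max=num_buckets - 1)
    return buckets + torch.where(is_small, rel, rel_large)


def compute_bias(relative_attention_bias: torch.nn.Embedding,
                 q_len: int, k_len: int) -> torch.Tensor:
    """Pairwise relative positions -> buckets -> per-head learned bias matrix."""
    context_position = torch.arange(q_len)[:, None]
    memory_position = torch.arange(k_len)[None, :]
    buckets = relative_position_bucket(memory_position - context_position)
    bias = relative_attention_bias(buckets)       # (q_len, k_len, n_heads)
    return bias.permute(2, 0, 1).unsqueeze(0)     # (1, n_heads, q_len, k_len)
```

The resulting matrix is the $A$ term added to the attention scores before $softmax$.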
Existing attention bias support
Currently, no vLLM attention backend fully supports passing in a custom attention bias. This is primarily due to underlying kernel limitations. For example, the xFormers memory_efficient_attention_forward kernel is the only NVIDIA-GPU-oriented kernel which permits passing in an arbitrary PyTorch tensor as a materialized attention bias (via the attn_bias argument); at time of writing I have not investigated whether custom attention bias is supported by any of the kernels for AMD GPU, CPU, etc. Regardless, vLLM only employs xFormers memory_efficient_attention_forward for prefill; to my knowledge, none of the decode-phase kernels employed by vLLM can accept an arbitrary tensor as a custom attention bias, making custom attention bias impossible to apply end-to-end for both prefill and decode under the current vLLM implementation.
In addition to the lack of kernel-level support for custom attention bias, most vLLM backends also prevent passing a custom attention bias matrix to the underlying kernel. The exception is the XFormers backend, which accepts an attention bias via the XFormersMetadata.attn_bias attribute (however, the XFormers backend only utilizes attn_bias in the prefill phase.)
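For concreteness, here is a minimal standalone illustration (not vLLM code) of the kernel-level capability referred to above: xFormers' memory_efficient_attention_forward accepting a fully-materialized bias tensor via its attn_bias argument. It assumes a CUDA device and an xformers installation:

```python
import torch
from xformers.ops import memory_efficient_attention_forward

# xFormers expects (batch, seq_len, num_heads, head_dim) inputs.
B, H, S, D = 1, 8, 128, 64
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)

# Arbitrary materialized attention bias, one (S, S) matrix per head.
bias = torch.randn(B, H, S, S, device="cuda", dtype=torch.float16)

out = memory_efficient_attention_forward(q, k, v, attn_bias=bias)
```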
Proposed methods for supporting custom attention bias
The following two approaches for supporting custom attention bias in vLLM are proposed:
- Fully-materialized bias matrix: modify vLLM attention backends to accept an arbitrary PyTorch tensor, passed into the backend via the AttentionMetadata.attn_bias field.
- On-the-fly/fused bias matrix computation: enable an efficient workflow whereby vLLM developers can tweak an attention kernel to compute the custom attention bias on the fly.
  - For example: rather than computing the T5 relative positional encoding bias matrix once, the attention kernel can fuse the element-wise bias formula with the $Q K^T$ and $softmax()$ computations; the attention bias matrix is never fully materialized.
  - FlexAttention enables fused custom attention bias computations in a FlashAttention-style kernel, using torch.compile.
It may make sense to support one or both of these methods.
Note that custom attention bias support must be added on a backend-by-backend basis, because of the kernel modifications & backend logic changes required.
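As a minimal sketch of the fused approach, the snippet below uses PyTorch FlexAttention's score_mod hook (available in recent PyTorch releases) to apply an element-wise distance bias inside the kernel; an ALiBi-style bias stands in for T5's bucketed bias, and the bias matrix is never materialized:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# FlexAttention expects (batch, num_heads, seq_len, head_dim) inputs.
B, H, S, D = 1, 8, 128, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
slopes = torch.rand(H, device="cuda")


def distance_bias(score, b, h, q_idx, kv_idx):
    # Element-wise bias formula, fused with the Q K^T and softmax computations.
    return score + slopes[h] * (kv_idx - q_idx)


out = torch.compile(flex_attention)(q, k, v, score_mod=distance_bias)
```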
Initial goals for introducing custom attention bias support
- Focus on a particular vLLM attention backend
- Suggestion: focus on an attention backend which also supports encoder/decoder models, in order to facilitate running T5. At time of writing, XFormers is the only backend which supports encoder/decoder models; however, there will likely be work on supporting encoder/decoder in additional attention backends.
- Scope out the effort involved in introducing custom attention bias support to this backend
- Some steps which will likely be involved in introducing custom attention bias support:
  - Augment the attention backend's kernels to accept a custom attention bias; for example, the PagedAttention kernel (for the XFormers backend), the flash-attention kernel (for the flash-attn backend), or the FlashInfer kernels (for the FlashInfer backend)
  - (Except for XFormers) add an attn_bias attribute to the attention backend's AttentionMetadata subclass
  - Ensure that the attention backend passes the attn_bias attribute to both the prefill and decode kernels
- Add at least two custom attention bias unit tests (for prefill & decode respectively)
Final goals for introducing custom attention bias support
- All vLLM attention backends should support custom attention bias, with unit tests
Some links which may be helpful for understanding how causal & non-causal attention masks are currently configured in vLLM:
- Invocation of flash-attention for prefill in vLLM backend, using causal flag
- Invocation of FlashInfer attention kernel for prefill in backend, using causal flag
- Invocation of PagedAttention kernel for decode in vLLM backend
Support CUDAGraph with encoder/decoder models #
Note: this topic is being tracked by Issue #7447
Steps to support CUDAGraph with encoder/decoder models:
- Scope out the effort required to support CUDAGraph with encoder/decoder models
- Write a PR for CUDAGraph + encoder/decoder
- Remove the assertion which fails when CUDAGraph is enabled for an encoder/decoder model (see assert_enc_dec_mr_supported_scenario() in vllm/worker/utils.py)
Support pipeline-parallelism with encoder/decoder models #
Steps to support pipeline-parallelism with encoder/decoder models:
- Scope out the effort required to support pipeline-parallelism with encoder/decoder models
- Write a PR for pipeline-parallelism + encoder/decoder
- Remove the assertion which fails when pipeline-parallelism is enabled for an encoder/decoder model (see assert_enc_dec_mr_supported_scenario() in vllm/worker/utils.py)
Support multi-step scheduling with encoder/decoder models #
Note: depends on #7000 landing in order to add multi-step scheduling support; it may be helpful to review the multi-step scheduling RFC ( #6854 )
Steps to support multi-step scheduling with encoder/decoder models:
- Scope out the effort required to support multi-step scheduling
  - EncoderDecoderModelRunner multi-step support
- Write a PR for multi-step scheduling + encoder/decoder
- Write at least one test of an encoder/decoder model with multi-step scheduling
Low-priority high-effort tasks #
- Speculative decoding
- Automatic prefix caching
Here it is proposed that these features are low-priority. Adding support for speculative decoding and automatic prefix caching would require a significant amount of effort to scope out and design the implementations.
Note that adding support for either of these features would require removing the assertions which fail when speculative decoding or automatic prefix caching are enabled for an encoder/decoder model (see assert_enc_dec_mr_supported_scenario() in vllm/worker/utils.py).
Feedback Period.
Closed.
CC List.
@WoosukKwon
@robertgshaw2-neuralmagic
@mgoin
@tms
@njhill
@sroy745
@ywang96
@DarkLight1337
@js8544
Any Other Things.
No response