Description
Motivation.
I am working on fusing FP8 quantization into (V0) attention with a new `torch.compile` pass. However, this pass will require some other changes to core vLLM code, so I wanted to collect people's opinions before proceeding.
In its ROCm fork, AMD saw that manually fusing quantization into the attention op led to a major performance improvement. When fused, the attention output is quantized inside the attention Triton kernel before being written out, which saves data movement by avoiding a full round trip to global memory at 16-bit precision.
To avoid modifying the model definitions (the approach used in the ROCm fork), we are planning to do the fusion with a `torch.compile` pass. Apart from the pass implementation itself, there are two obstacles in the Python integration code:
- Attention is wrapped in the `unified_attention` custom op.
- In V1, piecewise compilation means that the ops we want to fuse end up in different graphs.
Proposed Change.
1. Extend `unified_attention`
EDIT: @youkaichao has suggested directly storing `output_scale` on the attention object, which will also make it easier to tell whether the attention backend supports fusing the op. The proposal below is therefore only needed if that approach does not work out.
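A rough sketch of that alternative, assuming the pass can look up the `Attention` module by its layer name; the `supports_quant_fusion` flag and the `_o_scale` attribute are made-up names used only for illustration:

```python
import torch

def try_fuse_output_quant(attn_layer, scale: torch.Tensor) -> bool:
    """Stash the output scale on the attention layer instead of threading it
    through the custom op signature (sketch of the suggestion above)."""
    # A backend that cannot write quantized output simply never sets the flag,
    # which also answers "does this backend support the fusion?" for free.
    if not getattr(attn_layer.impl, "supports_quant_fusion", False):
        return False
    attn_layer._o_scale = scale  # read by the backend when writing the output
    return True
```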
OLD:
For the first issue, I propose adding two optional parameters to the `unified_attention` op:

- `output_scale: Optional[torch.Tensor] = None`
- `quant_config: Optional[QuantConfig] = None`, which contains:
  - `dynamic: bool` (is the quantization dynamic or static)
  - `per_token: bool` (is the quantization per-token or per-tensor; this could later be extended to other schemes, e.g. group quantization, if necessary)
  - `dtype: torch.dtype`

The second parameter could also be named `o_quant_config` or `out_quant_config` to avoid confusion with the quantization config for `QKV` and `P`.
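For concreteness, the config could look roughly like the dataclass below (a sketch; the name `QuantConfig` comes from the proposal, but custom-op schemas cannot carry arbitrary Python objects, so the pass would flatten it into plain arguments or a tuple when rewriting the graph):

```python
from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class QuantConfig:
    """Describes how the attention output should be quantized in-kernel."""
    dynamic: bool       # dynamic (computed per call) vs. static (precomputed) scale
    per_token: bool     # per-token vs. per-tensor scaling
    dtype: torch.dtype  # e.g. torch.float8_e4m3fn

    def as_tuple(self):
        # Flattened form that can be passed through the custom op schema,
        # matching the tuple used in the fx.Graph example below.
        return (self.dynamic, self.per_token, self.dtype)
```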
This way, we can make the following `fx.Graph` replacement:

```python
# Original
output = torch.ops.vllm.unified_attention(q, k, v, 'model.layers.0.self_attn.attn')
output_quant = torch.empty(...)
torch.ops._C.static_scaled_fp8_quant(output_quant, output, scale)  # in-place op

# Fused
output_quant = torch.ops.vllm.unified_attention(q, k, v, 'model.layers.0.self_attn.attn',
                                                output_scale=scale,
                                                quant_config=(True, True, torch.float8_e4m3fn))
```
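vLLM's existing fusion passes use the Inductor pattern matcher for rewrites like this, but as an illustration of the transformation itself, a direct `fx` traversal could look roughly as follows. This is a sketch only: it assumes the quant op shows up as a plain `call_function` node, whereas in practice it is auto-functionalized, which is exactly the in-place-op pain point mentioned at the end of this RFC:

```python
import torch
from torch import fx

def fuse_attn_output_quant(graph: fx.Graph) -> None:
    """Rewrite unified_attention -> static_scaled_fp8_quant into one fused call."""
    for quant in list(graph.nodes):
        if (quant.op != "call_function"
                or quant.target != torch.ops._C.static_scaled_fp8_quant.default):
            continue
        out_q, attn_out, scale = quant.args  # (quantized buffer, attn output, scale)
        if (not isinstance(attn_out, fx.Node)
                or attn_out.target != torch.ops.vllm.unified_attention.default):
            continue
        q, k, v, layer_name = attn_out.args
        with graph.inserting_before(attn_out):
            fused = graph.call_function(
                torch.ops.vllm.unified_attention.default,
                (q, k, v, layer_name),
                {"output_scale": scale,
                 "quant_config": (True, True, torch.float8_e4m3fn)})
        # Downstream users read the pre-allocated FP8 buffer, so point them at
        # the fused op's output instead, then drop the now-dead nodes.
        out_q.replace_all_uses_with(fused)
        graph.erase_node(quant)
        if not attn_out.users:
            graph.erase_node(attn_out)
    graph.eliminate_dead_code()
```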
`unified_attention_with_output` (often used in V1) can use the same approach, with slightly different handling for `output` (and a separate pattern-matcher pattern). If we only used `unified_attention_with_output`, `quant_config.dtype` would be redundant with the dtype of the passed output, but we need to support both ops, and the redundancy shouldn't hurt us here.
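For reference, the fused form for the `with_output` variant could look roughly like this (the exact argument order of `unified_attention_with_output` is recalled from memory and may differ; the point is that the kernel writes FP8 directly into the caller-provided buffer):

```python
# Original (V1): attention writes a 16-bit buffer, then the quant op reads it
# back from global memory and writes the FP8 buffer.
torch.ops.vllm.unified_attention_with_output(q, k, v, output, 'model.layers.0.self_attn.attn')
torch.ops._C.static_scaled_fp8_quant(output_quant, output, scale)

# Fused: the attention kernel quantizes before writing and fills output_quant
# directly; quant_config.dtype duplicates output_quant.dtype, which is harmless.
torch.ops.vllm.unified_attention_with_output(q, k, v, output_quant, 'model.layers.0.self_attn.attn',
                                             output_scale=scale,
                                             quant_config=(True, True, torch.float8_e4m3fn))
```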
2. Add "sticky ops" for graph splitting
Once we want to add support for attention+quant fusion in V1, we'll have to revisit our approach to piecewise graph splitting: basically, we need to make sure that `scaled_fp8_quant` ends up in the same graph as attention. For that, I suggest introducing the concept of "sticky ops" (specified by name or via a callback): ops that consume the output of attention and therefore end up in the attention subgraph. We would only enable this when fusion is enabled and possible, to make sure we still benefit from CUDA graphs for these ops when not fusing.
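As a rough illustration of what "sticky ops" could mean for the splitting logic (the names and the integration point here are hypothetical; the real change would live in vLLM's piecewise splitting code):

```python
import torch
from torch import fx

# Could also be provided via a callback registered by the fusion pass; only
# populated when attention+quant fusion is enabled and supported.
STICKY_OPS = {torch.ops._C.static_scaled_fp8_quant.default}

def is_sticky(node: fx.Node, attention_targets: set) -> bool:
    """True if this node should stay in the attention subgraph: it is a
    designated 'sticky' op that consumes an attention output, so the splitter
    must not start a new piecewise graph between attention and this op."""
    if node.op != "call_function" or node.target not in STICKY_OPS:
        return False
    return any(isinstance(arg, fx.Node) and arg.target in attention_targets
               for arg in node.args)
```

The splitter would consult this only when fusion is enabled, so in the unfused case these ops still land in a CUDA-graph-captured piece.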
Finally, this will further cement our reliance on pattern-matching in-place ops, so we should figure out #14703.
Feedback Period.
This week (April 7-11)
CC List.
@youkaichao @WoosukKwon @tlrmchlsmth @SageMoore @LucasWilkinson @rasmith @gshtras
Any Other Things.
No response