
[RFC]: Changes to support attention + quant fusion #16220

Open
@ProExpertProg

Description

Motivation.

I am working on fusing FP8 quantization onto (V0) attention with a new torch.compile pass. However, this pass will require some other changes to core vLLM code, so I wanted to collect people's opinions before proceeding.

In the ROCm fork, AMD saw that manually fusing quantization into the attention op led to a major performance improvement. When fused, the attention output is quantized inside the attention Triton kernel before being written out, which saves data movement by avoiding a full round trip to global memory at 16-bit precision.

To avoid modifying the model definitions (the approach used in the ROCm fork), we are planning to do the fusion with a torch.compile pass. Apart from the pass implementation, there are two obstacles in the Python integration code:

  1. Attention is wrapped into a unified_attention custom op
  2. In V1, piecewise compilation means that the ops we want to fuse appear in different graphs

Proposed Change.

1. Extend unified_attention

EDIT: @youkaichao has suggested directly storing output_scale on the attention object, which will also make it easier to figure out whether the attention backend supports fusing the op. So the proposal below is not necessary unless that approach does not work.
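As a rough illustration of that alternative (the lookup helper, the capability flag, and the attribute name below are all hypothetical, not existing vLLM APIs), the fusion pass would attach the scale to the Attention module it matched, and the backend would read it during forward and write the quantized output in-kernel:

# Hedged sketch only: names are hypothetical.
attn = lookup_attention_layer('model.layers.0.self_attn.attn')   # hypothetical lookup by layer name
if getattr(attn.impl, "supports_fused_output_quant", False):      # hypothetical backend capability flag
    attn.output_scale = scale  # backend quantizes in-kernel when this is set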

OLD:
For the first issue, I propose adding two optional parameters to the unified_attention op (sketched after the list below):

  • output_scale: Optional[torch.Tensor] = None
  • quant_config: Optional[QuantConfig] = None, which contains:
    • dynamic: bool (is the quantization dynamic or static)
    • per_token: bool (is quant per-token or per-tensor - this could later be extended to other quantization schemes (group, etc.) if necessary)
    • dtype: torch.dtype
    • This could also be named o_quant_config or out_quant_config to avoid confusion with the quantization config for QKV and P.
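A minimal sketch of the proposed dataclass and extended op signature, assuming the fields listed above (nothing here exists in vLLM yet, and names/field order are illustrative):

# Illustrative only: proposed QuantConfig and extended unified_attention signature.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class QuantConfig:
    dynamic: bool       # dynamic (True) vs. static (False) quantization
    per_token: bool     # per-token (True) vs. per-tensor (False) scales
    dtype: torch.dtype  # e.g. torch.float8_e4m3fn

def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    quant_config: Optional[QuantConfig] = None,
) -> torch.Tensor:
    ...  # existing attention dispatch, plus optional in-kernel output quantization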

This way, we can make the following fx.Graph replacement:

# Original
output = torch.ops.vllm.unified_attention(q, k, v, 'model.layers.0.self_attn.attn')
output_quant = torch.empty(...)
torch.ops._C.static_scaled_fp8_quant(output_quant, output, scale)  # in-place op

# Fused
output_quant = torch.ops.vllm.unified_attention(
    q, k, v, 'model.layers.0.self_attn.attn',
    output_scale=scale,
    quant_config=QuantConfig(False, False, torch.float8_e4m3fn))  # static, per-tensor fp8

unified_attention_with_output (often used in V1) can use the same approach, with slightly different handling of the output buffer (and a separate pattern-matcher pattern). If we only used unified_attention_with_output, quant_config.dtype would be redundant with the dtype of the passed output buffer, but we need to support both ops, and the redundancy shouldn't hurt us here.
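For reference, the fused form of the with-output variant might look roughly like this (a sketch using the QuantConfig above; the exact argument order of unified_attention_with_output is assumed and may differ). The output buffer is allocated in the quantized dtype up front, which is why quant_config.dtype and the buffer dtype carry the same information:

# Sketch only: fused replacement for unified_attention_with_output.
output_quant = torch.empty(out_shape, dtype=torch.float8_e4m3fn, device=q.device)
torch.ops.vllm.unified_attention_with_output(
    q, k, v, output_quant, 'model.layers.0.self_attn.attn',
    output_scale=scale,
    quant_config=QuantConfig(False, False, torch.float8_e4m3fn))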

2. Add "sticky ops" for graph splitting

Once we add support for attention+quant fusion to V1, we'll have to revisit our approach to piecewise graph splitting: we need to make sure that scaled_fp8_quant ends up in the same subgraph as attention. For that, I suggest introducing the concept of "sticky ops" (specified by name or via a callback): ops that consume the output of attention and are kept in the attention subgraph. We should only enable this when fusion is enabled and possible, so that we still benefit from CUDA graphs for these ops when not fusing.
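To make the idea concrete, here is a minimal sketch of how sticky ops could influence the partitioning (function and variable names are hypothetical, and vLLM's actual splitting code is structured differently): after the usual split decision, any sticky op consuming an attention output is reassigned to its producer's subgraph, so the fusion pass sees attention + quant in one graph.

# Hypothetical helper, not vLLM's real splitting code.
from torch import fx

ATTENTION_OPS = {"unified_attention", "unified_attention_with_output"}
STICKY_OPS = {"static_scaled_fp8_quant"}  # only populated when fusion is enabled

def _matches(node: fx.Node, names: set[str]) -> bool:
    return node.op == "call_function" and any(n in str(node.target) for n in names)

def apply_sticky_ops(graph: fx.Graph, partition: dict[fx.Node, int]) -> None:
    """Mutate `partition` (node -> subgraph id) so sticky ops follow their attention producer."""
    for node in graph.nodes:
        if not _matches(node, STICKY_OPS):
            continue
        for inp in node.all_input_nodes:
            if _matches(inp, ATTENTION_OPS):
                # Keep the quant op in the same subgraph as the attention op
                # whose output it consumes.
                partition[node] = partition[inp]
                break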

Finally, this will further cement our reliance on pattern-matching in-place ops, so we should figure out #14703.

Feedback Period.

This week (April 7-11)

CC List.

@youkaichao @WoosukKwon @tlrmchlsmth @SageMoore @LucasWilkinson @rasmith @gshtras

Any Other Things.

No response

