[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass #16756

Merged: 32 commits, merged on Jun 12, 2025

Commits (changes from all commits):
db3857b - collapse multiple reshapes into 1 (ProExpertProg, Apr 16, 2025)
858c09e - before modifying the attention op, manual replacement (ProExpertProg, Apr 13, 2025)
6de4cd9 - Add output_scale to all attention backends (ProExpertProg, Apr 16, 2025)
e8c9407 - AttentionFusionPass works! But the Triton kernel exits spectacularly,… (ProExpertProg, Apr 17, 2025)
03b3b50 - Add realistic test for triton_flash_attention, passes, unclear why. (ProExpertProg, Apr 17, 2025)
4c116ae - Add flag for attention fusion pass (ProExpertProg, Apr 17, 2025)
fb460d9 - Manually quantize in decode (ProExpertProg, Apr 17, 2025)
efa02b4 - Fix tmp_output dtype (SageMoore, Apr 18, 2025)
c849929 - More representative test, does not work though. (ProExpertProg, Apr 18, 2025)
e4c6fb8 - Cleanup fusion pass, add timing logic (ProExpertProg, Apr 18, 2025)
0a59a0e - Test in progress (ProExpertProg, Apr 22, 2025)
8f0fe2f - Integrate fused out quant for custom paged attention, fix dtype, elim… (ProExpertProg, Apr 25, 2025)
c0a0d96 - with pattern match (ProExpertProg, Apr 22, 2025)
8f067f3 - Add trace function for reshapes, only TODO integrate with custom checks (ProExpertProg, Apr 25, 2025)
3c9cc6b - PatternMatcher approach working! (ProExpertProg, Apr 30, 2025)
80d0052 - PR comments, cleanup (ProExpertProg, May 1, 2025)
0be957f - temp debugging of unittest (ProExpertProg, May 5, 2025)
fef905c - Comment (ProExpertProg, May 7, 2025)
5009e6b - PR comments (ProExpertProg, May 13, 2025)
adb01c3 - TEMP attntion+quant test (ProExpertProg, Jun 9, 2025)
02e494e - Refactored ops check in backend, util for finding nodes (ProExpertProg, Jun 10, 2025)
06c6d87 - Working attention fusion test (ProExpertProg, Jun 10, 2025)
7cec3c4 - PR comments, GroupShape, fix typing default (ProExpertProg, Jun 10, 2025)
2c68f97 - Test working with dynamo reset, prompt 4 still different. (ProExpertProg, Jun 10, 2025)
f352856 - Fixed recompilation issue, test working (ProExpertProg, Jun 11, 2025)
a3e68d9 - Format (ProExpertProg, Jun 11, 2025)
379b55d - Format, fix Dynamo caching (ProExpertProg, Jun 11, 2025)
a314e87 - Add test to CI (ProExpertProg, Jun 11, 2025)
6408ebf - Remove test changes (ProExpertProg, Jun 11, 2025)
f2b0d01 - More cleanup for tests: (ProExpertProg, Jun 11, 2025)
98de2f9 - Add output_scale to new attn backends (ProExpertProg, Jun 11, 2025)
1a8a794 - pre-commit fix (ProExpertProg, Jun 11, 2025)
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -305,6 +305,7 @@ steps:
commands:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
30 changes: 16 additions & 14 deletions tests/compile/backend.py
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union

from torch import fx
from torch._ops import OpOverload

from vllm.compilation.fx_utils import (find_specified_fn,
find_specified_fn_maybe)
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config

@@ -48,18 +49,19 @@ def post_pass(self, graph: fx.Graph):
# assign by reference, will reflect the final state of the graph
self.final_graph = graph

def check_before_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
for op in ops:
find_fn(self.graph_pre_pass.nodes, op)
if ops_fully_replaced:
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, \
f"Unexpected op {op.name()} in post-pass graph"

def check_after_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
def check_after_ops(self, ops: Sequence[OpOverload]):
for op in ops:
find_fn(self.graph_post_pass.nodes, op)
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
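
For reference, a minimal sketch of what a find_op_nodes-style helper can look like, assuming it simply yields call_function nodes whose target matches the requested overload; the actual utility in vllm.compilation.fx_utils likely also handles auto-functionalized ops (the find_auto_fn helpers it replaces) and may differ in detail:

from collections.abc import Iterator

from torch import fx
from torch._ops import OpOverload


def find_op_nodes_sketch(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
    """Yield every call_function node in the graph that targets the given op."""
    for node in graph.nodes:
        if node.op == "call_function" and node.target == op:
            yield node

With such a helper, check_before_ops and check_after_ops reduce to counting matches in the pre- and post-pass graphs, which is what the rewritten assertions above do.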
3 changes: 1 addition & 2 deletions tests/compile/test_async_tp.py
@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,

# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(),
ops_fully_replaced=False)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist
12 changes: 5 additions & 7 deletions tests/compile/test_fusion.py
@@ -7,8 +7,7 @@
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
@@ -30,9 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
group_shape=group_shape,
symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
find_auto_fn_maybe)
backend.check_before_ops(model.ops_in_model_before())

# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
find_auto_fn_maybe)
backend.check_after_ops(model.ops_in_model_after())
131 changes: 131 additions & 0 deletions tests/compile/test_fusion_attn.py
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform

# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
Review comment (Collaborator Author):
@zou3519 this was a weird thing I ran into, wasn't sure it was worth running down but lmk if you want help with a repro


# Use global backends
global backend, backend_unfused

use_v1 = False # can be made a param once V1 support added
monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

# Prompt 4 seems too open-ended, differs between fused and unfused
# (both outputs look reasonable though)
prompts = example_prompts[:4] + example_prompts[5:]

compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
)
vllm_config = VllmConfig(compilation_config=compile_config)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=10,
top_p=0.95)

unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
del llm

compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
)
vllm_config = VllmConfig(compilation_config=compile_config)

# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)

# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape)
for key, layer in compile_config.static_forward_context.items()
]

print(f"{attn_fusion_supported=}")
if any(attn_fusion_supported):
# Check quant ops
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

# attention ops present in both, just output_scale param changes
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
assert len(attn_nodes_pre) == len(attn_nodes_post)

for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \
f"Node {i} {'' if fused else 'not '} expected " \
f"to have fused output quant"

# check outputs
fused_output = llm2.generate(prompts, sampling_params)

# transform outputs to format expected by check_outputs_equal
sample_outs = lambda s: (list(s.token_ids), s.text)
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

check_outputs_equal(
outputs_0_lst=outs_lst(unfused_output),
outputs_1_lst=outs_lst(fused_output),
name_0="unfused",
name_1="fused",
)

# Clean Dynamo cache to avoid polluting other case(s)
torch._dynamo.reset()

# Reset backend to make sure llm2 gets released
backend = None
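
As background for the string-import backend globals above, here is a minimal sketch, under assumed names, of the mechanism TestBackend builds on: a Dynamo backend is just a callable taking a GraphModule and example inputs and returning a callable, so a module-level instance can be referenced by a dotted path from CompilationConfig. The pass interface shown is illustrative, not vLLM's exact API.

from copy import deepcopy
from typing import Callable, Optional

from torch import fx


class RecordingBackend:
    """Applies graph passes and keeps pre-/post-pass copies of the FX graph."""

    def __init__(self, *passes: Callable[[fx.Graph], None]):
        self.passes = passes
        self.graph_pre_pass: Optional[fx.Graph] = None
        self.graph_post_pass: Optional[fx.Graph] = None

    def __call__(self, gm: fx.GraphModule, example_inputs):
        self.graph_pre_pass = deepcopy(gm.graph)
        for p in self.passes:
            p(gm.graph)  # passes mutate the graph in place
        self.graph_post_pass = gm.graph
        gm.recompile()
        return gm.forward  # run the transformed graph eagerly


# Module-level instance, so a config string such as
# "my_tests.my_module.backend" (hypothetical path) can resolve to it.
backend = RecordingBackend()

Calling torch.compile(fn, backend=backend) hands every captured graph to this object, which mirrors how the test inspects graph_pre_pass and graph_post_pass after llm.generate().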
8 changes: 7 additions & 1 deletion vllm/_custom_ops.py
@@ -1228,6 +1228,7 @@ def scaled_fp8_quant(
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -1259,7 +1260,12 @@
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype

if scale is None:
if use_per_token_if_dynamic:
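
A hedged usage sketch of the new output argument: the caller pre-allocates the FP8 buffer and scaled_fp8_quant writes into it instead of allocating, which is useful when fusing quantization onto an op that already owns its output buffer; per the assertion above, this mode cannot be combined with num_token_padding. Shapes, the scale value, and the device below are illustrative.

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
out = torch.empty_like(x, dtype=current_platform.fp8_dtype())

# Quantized values land in the pre-allocated buffer; no token padding allowed.
quantized, scale = ops.scaled_fp8_quant(x, scale=static_scale, output=out)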
17 changes: 17 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -284,9 +284,25 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.

TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape. (-1, -1) for per-tensor.
:return: is fusion supported for this type of quantization
"""
return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

@@ -300,6 +316,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

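
An illustrative override of the hook introduced above, for a hypothetical backend class: it advertises support only for static, per-tensor FP8 output quantization, using the docstring's (-1, -1) convention for the per-tensor group shape. Only the method signature comes from the diff; the class name and the supported combination are made up.

import torch

from vllm.platforms import current_platform


class MyPagedAttentionImpl:  # a real backend would subclass AttentionImpl
    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        # Only static, per-tensor FP8 output quantization is fused here.
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == (-1, -1))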
6 changes: 6 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -374,6 +374,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

@@ -388,6 +389,11 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
9 changes: 9 additions & 0 deletions vllm/attention/backends/dual_chunk_flash_attn.py
@@ -370,6 +370,8 @@ def forward( # type: ignore
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
@@ -383,6 +385,13 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is None, "Output tensor not supported for DualChunk"

if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

(
query,
query_succ,
6 changes: 6 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -673,6 +673,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

@@ -692,6 +693,11 @@ def forward(
"""
assert output is not None, "Output tensor must be provided."

if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
6 changes: 6 additions & 0 deletions vllm/attention/backends/flashinfer.py
@@ -975,8 +975,14 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")

# TODO: directly write to output tensor
num_heads: int = self.num_heads
head_size: int = self.head_size
6 changes: 6 additions & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -181,6 +181,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

@@ -193,6 +194,11 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")

batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
