[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass #16756
Merged: simon-mo merged 32 commits into vllm-project:main from neuralmagic:luka/fusion-attention-fp8 on Jun 12, 2025.
Commits (32):
db3857b  collapse multiple reshapes into 1 (ProExpertProg)
858c09e  before modifying the attention op, manual replacement (ProExpertProg)
6de4cd9  Add output_scale to all attention backends (ProExpertProg)
e8c9407  AttentionFusionPass works! But the Triton kernel exits spectacularly,… (ProExpertProg)
03b3b50  Add realistic test for triton_flash_attention, passes, unclear why. (ProExpertProg)
4c116ae  Add flag for attention fusion pass (ProExpertProg)
fb460d9  Manually quantize in decode (ProExpertProg)
efa02b4  Fix tmp_output dtype (SageMoore)
c849929  More representative test, does not work though. (ProExpertProg)
e4c6fb8  Cleanup fusion pass, add timing logic (ProExpertProg)
0a59a0e  Test in progress (ProExpertProg)
8f0fe2f  Integrate fused out quant for custom paged attention, fix dtype, elim… (ProExpertProg)
c0a0d96  with pattern match (ProExpertProg)
8f067f3  Add trace function for reshapes, only TODO integrate with custom checks (ProExpertProg)
3c9cc6b  PatternMatcher approach working! (ProExpertProg)
80d0052  PR comments, cleanup (ProExpertProg)
0be957f  temp debugging of unittest (ProExpertProg)
fef905c  Comment (ProExpertProg)
5009e6b  PR comments (ProExpertProg)
adb01c3  TEMP attntion+quant test (ProExpertProg)
02e494e  Refactored ops check in backend, util for finding nodes (ProExpertProg)
06c6d87  Working attention fusion test (ProExpertProg)
7cec3c4  PR comments, GroupShape, fix typing default (ProExpertProg)
2c68f97  Test working with dynamo reset, prompt 4 still different. (ProExpertProg)
f352856  Fixed recompilation issue, test working (ProExpertProg)
a3e68d9  Format (ProExpertProg)
379b55d  Format, fix Dynamo caching (ProExpertProg)
a314e87  Add test to CI (ProExpertProg)
6408ebf  Remove test changes (ProExpertProg)
f2b0d01  More cleanup for tests: (ProExpertProg)
98de2f9  Add output_scale to new attn backends (ProExpertProg)
1a8a794  pre-commit fix (ProExpertProg)
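The commit history above outlines the approach: an `output_scale` parameter is added to every attention backend, and a torch.compile pass (built on the pattern matcher, per the "PatternMatcher approach working!" commit) rewrites the graph so that a static FP8 quantization of the attention output is folded into the attention op itself. The toy `torch.fx` rewrite below is only an illustrative sketch of that idea, not the actual vLLM `AttnFusionPass`; `fake_attention`, `fake_quant_fp8`, and `ToyLayer` are hypothetical stand-ins.

```python
# Toy sketch of the attention + output-quant fusion idea. This is NOT vLLM's
# AttnFusionPass; fake_attention and fake_quant_fp8 are hypothetical
# stand-ins for the real attention and quantization ops.
import torch
import torch.fx as fx


def fake_attention(q, k, v, output_scale=None):
    out = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v
    if output_scale is not None:
        # Pretend the kernel can emit already-quantized FP8 output directly.
        out = (out / output_scale).to(torch.float8_e4m3fn)
    return out


def fake_quant_fp8(x, scale):
    return (x / scale).to(torch.float8_e4m3fn)


# Keep both functions as leaf call_function nodes instead of tracing into them.
fx.wrap("fake_attention")
fx.wrap("fake_quant_fp8")


class ToyLayer(torch.nn.Module):

    def forward(self, q, k, v, scale):
        attn = fake_attention(q, k, v)      # unfused: attention ...
        return fake_quant_fp8(attn, scale)  # ... followed by a separate quant


def fuse_attn_output_quant(gm: fx.GraphModule) -> fx.GraphModule:
    """Fold a static quant of the attention output into the attention call."""
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target is not fake_quant_fp8:
            continue
        attn_node, scale = node.args
        if (isinstance(attn_node, fx.Node)
                and attn_node.op == "call_function"
                and attn_node.target is fake_attention
                and len(attn_node.users) == 1):
            # Hand the scale to the attention op and drop the quant node.
            attn_node.kwargs = {**attn_node.kwargs, "output_scale": scale}
            node.replace_all_uses_with(attn_node)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = fuse_attn_output_quant(fx.symbolic_trace(ToyLayer()))
print(gm.graph)  # quant node is gone; fake_attention now takes output_scale
```

After the rewrite, the graph contains a single attention call that receives the quantization scale directly, which is the property the test file below asserts through the `output_scale` kwarg on the attention nodes.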
New file: tests/compile/test_fusion_attn.py (+131 lines)

# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform

# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # check support
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for key, layer in compile_config.static_forward_context.items()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i in range(len(attn_nodes_pre)):
        assert attn_nodes_pre[i].kwargs["output_scale"] is None
        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
        assert fused == attn_fusion_supported[i], \
            f"Node {i} {'' if fused else 'not '}expected " \
            f"to have fused output quant"

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None
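The test drives compilation through `CompilationLevel.DYNAMO_AS_IS` with a string-imported custom Dynamo backend (hence the module-level `backend` / `backend_unfused` globals), which lets it inspect the FX graph before and after the passes via `graph_pre_pass` / `graph_post_pass`. The real `TestBackend` lives in `tests/compile/backend.py` and is not part of the excerpt above; the following is only a rough sketch, under stated assumptions, of what such a graph-capturing backend can look like.

```python
# Rough, assumption-laden sketch of a graph-capturing Dynamo backend in the
# spirit of TestBackend; the real tests/compile/backend.py likely differs.
import copy
from typing import Callable, List, Optional

import torch
import torch.fx as fx


class GraphCaptureBackend:
    """Runs the given graph passes and keeps before/after graphs for asserts."""

    def __init__(self, *passes: Callable[[fx.Graph], None]):
        self.passes = passes
        self.graph_pre_pass: Optional[fx.Graph] = None
        self.graph_post_pass: Optional[fx.Graph] = None

    def __call__(self, gm: fx.GraphModule,
                 example_inputs: List[torch.Tensor]):
        # Snapshot the graph Dynamo handed us before any pass mutates it.
        self.graph_pre_pass = copy.deepcopy(gm).graph
        for pass_ in self.passes:
            pass_(gm.graph)
        gm.recompile()
        self.graph_post_pass = gm.graph
        # A Dynamo backend must return a callable; the rewritten module works.
        return gm


# Usage: compiled = torch.compile(model, backend=GraphCaptureBackend(my_pass))
```

In the test above, the passes are `NoOpEliminationPass` and the lazily constructed `AttnFusionPass`, and the captured before/after graphs are what `find_op_nodes` and `check_before_ops` operate on.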
Review comment: @zou3519 this was a weird thing I ran into; I wasn't sure it was worth running down, but let me know if you want help with a repro.