[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix #18100
Conversation
# avoid power of 2 issue and pad it
BLOCK_Q_NUM_QUERY_PER_KV_PADDED: tl.constexpr = triton.next_power_of_2(
    BLOCK_Q * num_queries_per_kv)
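For context, a minimal host-side sketch of what this padding does, using the values reported in the linked issue (illustration only, not the kernel code itself):

# Hedged illustration: triton.next_power_of_2 rounds the 15 real query lanes
# up to 16, so one padded lane does not correspond to a real query.
import triton

BLOCK_Q = 3                 # BLOCK_M // num_queries_per_kv with BLOCK_M = 16
num_queries_per_kv = 5      # llama4 128E FP8 value from the issue
real_size = BLOCK_Q * num_queries_per_kv        # 15, not a power of 2
padded = triton.next_power_of_2(real_size)      # 16
print(real_size, padded)                        # 15 16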
Wouldn't we also need to introduce masking somewhere to account for this?
will check on this.
@tdoublep I also noticed that if we pass BLOCK_M as the parameter, which is hard-coded to 16 now, we can prevent this power-of-two issue and also simplify the code without needing the padding:
BLOCK_M = 16
BLOCK_Q = BLOCK_M // num_queries_per_kv
So BLOCK_Q * num_queries_per_kv is essentially BLOCK_M. Please let me know.
(Though I'm not sure why BLOCK_M is hard-coded to 16 and not some other value.)
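For concreteness, a tiny plain-Python check of that arithmetic (values taken from the debugging output in the PR description):

# Because of the integer division, the product is 15 rather than 16; it equals
# BLOCK_M only when num_queries_per_kv divides BLOCK_M exactly.
BLOCK_M = 16
num_queries_per_kv = 5
BLOCK_Q = BLOCK_M // num_queries_per_kv         # 16 // 5 = 3
print(BLOCK_Q * num_queries_per_kv)             # 15, <= BLOCK_M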
This pull request has merge conflicts that must be resolved before it can be merged.
@hongxiayang This does not seem to explicitly disable the fallback - is that intentional?
@tdoublep yes, it is intentional not to change the fall-back logic. I remember Greg has a PR to add an environment variable to conditionally enable/disable the fall-back.
The changes look fine to me: it simplifies the kernel nicely. Thanks for catching some of the mistakes in the comments too. I'm still a little bit unsure how the actual performance will be for this model where it seems like BLOCK_M=15. Did you do any benchmarking at all? This doesn't need to block merging this, since the changes can't make anything worse.
I wasn't aware that the test was failing on MI300x actually. Are we running that as part of CI?
cc @DarkLight1337 are you comfortable to approve/merge this one? Thanks.
Looks good.
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
    None, torch.float8_e4m3fnuz
]
This is incorrect, it should use current_platform.fp8_dtype():
QDTYPES = [None, current_platform.fp8_dtype()]
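For reference, a minimal sketch of how the suggested change could look, assuming current_platform is imported from vllm.platforms as elsewhere in the tests:

# Sketch of the suggested simplification: fp8_dtype() is expected to return the
# fp8 dtype appropriate for the current platform, so no is_rocm() branch is needed.
from vllm.platforms import current_platform

QDTYPES = [None, current_platform.fp8_dtype()]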
FIX #18088
As detailed in the above issue, when running llama4 models on V1 we saw an exception requiring the parameter to be a power of 2. Specifically, when running llama4 128E FP8 models, the expression BLOCK_Q * num_queries_per_kv in https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py#L97 is not a power of 2.
Debugging found that those values are:
print("BLOCK_Q:", BLOCK_Q) -> 3
print("num_queries_per_kv:", num_queries_per_kv) -> 5
print("Product:", BLOCK_Q * num_queries_per_kv) -> 15
(1) Noticed that if we pass BLOCK_M, which is hard-coded to 16 now, as the kernel parameter, we can prevent this power-of-two issue and also simplify the code without needing the padding. Since
BLOCK_Q = BLOCK_M // num_queries_per_kv
the product BLOCK_Q * num_queries_per_kv is essentially BLOCK_M.
(2) It also uses a tl.constexpr BLOCK_M in many places to avoid re-calculating the same expression later on; see the sketch after this list.
(3) This PR also fixed test_triton_unified_attention.py so that it can run successfully on ROCm.
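The sketch below illustrates the BLOCK_M-as-constexpr pattern from (1) and (2). It is a simplified stand-in, not the actual unified-attention kernel, and the masking shown is just one way the non-power-of-two remainder could be handled:

# Hedged sketch: BLOCK_M is passed as tl.constexpr, so tl.arange always gets a
# power-of-two size; lanes beyond BLOCK_Q * num_queries_per_kv are masked off
# instead of being padded.
import triton
import triton.language as tl

@triton.jit
def _block_m_sketch(out_ptr, num_queries_per_kv, BLOCK_M: tl.constexpr):
    BLOCK_Q = BLOCK_M // num_queries_per_kv          # e.g. 16 // 5 = 3
    offs_m = tl.arange(0, BLOCK_M)                   # power-of-two range: OK
    valid = offs_m < BLOCK_Q * num_queries_per_kv    # mask out the padded lanes
    tl.store(out_ptr + offs_m, offs_m, mask=valid)

Launched with BLOCK_M=16 and num_queries_per_kv=5, this writes indices 0..14 and leaves the single padded lane untouched.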
Tests:
Initially, test_triton_unified_attention.py would abort on the ROCm platform. After the unit test fix, it is able to run the full test suite without issue.
(1) Passed test_triton_unified_attention.py.
(2) After this change, I was able to run the llama4 model in V1 (by bypassing the fall-back fix).