
[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix #18100


Merged: 8 commits merged into vllm-project:main on May 28, 2025

Conversation

@hongxiayang (Collaborator) commented May 13, 2025

FIX #18088

As detailed in the issue above, when running llama4 models on V1 we hit an exception requiring a kernel parameter to be a power of 2. When running the llama4 128E FP8 model, the tl.arange extent in the following expression (https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py#L97) is not a power of 2:

 offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)

Debugging found that those values are:
print("BLOCK_Q:", BLOCK_Q) -> 3
print("num_queries_per_kv:", num_queries_per_kv) -> 5
print("Product:", BLOCK_Q * num_queries_per_kv) -> 15

I noticed that if we instead pass BLOCK_M, which is currently hard-coded to 16, as the kernel parameter, we avoid the power-of-2 issue and also simplify the code by removing the need for padding.

BLOCK_M = 16
BLOCK_Q = BLOCK_M // num_queries_per_kv

So BLOCK_Q * num_queries_per_kv is essentially BLOCK_M, and BLOCK_M (16) is a power of 2.
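
As a quick illustration (plain Python arithmetic only, using the values from the debug output above; this is not the kernel code):

num_queries_per_kv = 5                    # from the debug output above
BLOCK_M = 16                              # hard-coded block size in the kernel
BLOCK_Q = BLOCK_M // num_queries_per_kv   # 3

def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

print(BLOCK_Q * num_queries_per_kv, is_power_of_two(BLOCK_Q * num_queries_per_kv))  # 15 False
print(BLOCK_M, is_power_of_two(BLOCK_M))                                            # 16 True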

(2) It also introduces a tl.constexpr BLOCK_M and uses it throughout the kernel, avoiding repeated recalculation of the same expression (a short sketch of the idea follows point (3) below).

(3) This PR also fixes test_triton_unified_attention.py so that it runs successfully on ROCm.
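
A minimal sketch of the tl.constexpr idea behind point (2). The kernel name, signature, and the toy store below are illustrative only and differ from the real triton_unified_attention.py kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def _block_m_sketch(out_ptr, num_queries_per_kv: tl.constexpr, BLOCK_M: tl.constexpr):
    # BLOCK_Q is derived from BLOCK_M instead of being passed separately, and
    # tl.arange gets BLOCK_M (a power of 2) directly, so no padding is needed.
    BLOCK_Q: tl.constexpr = BLOCK_M // num_queries_per_kv
    offs_m = tl.arange(0, BLOCK_M)
    tl.store(out_ptr + offs_m, offs_m // BLOCK_Q)  # toy computation, just to exercise BLOCK_Q

# usage sketch (requires a GPU with a working Triton install)
out = torch.empty(16, dtype=torch.int32, device="cuda")
_block_m_sketch[(1,)](out, num_queries_per_kv=5, BLOCK_M=16)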

Tests:

Initially, test_triton_unified_attention.py would abort on the ROCm platform. After the unit-test fix, the full test suite runs without issue.

(1) Passed the test_triton_unified_attention.py

root@tx:/vllm/tests/kernels/attention# pytest test_triton_unified_attention.py
/usr/local/lib/python3.12/dist-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.6.0
rootdir: /dockerx/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, asyncio-1.0.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 1152 items                                                                                                                                      

test_triton_unified_attention.py .................................................................................................................. [  9%]
................................................................................................................................................... [ 22%]
................................................................................................................................................... [ 35%]
................................................................................................................................................... [ 48%]
.....................ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssss [ 60%]
ssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............sssssssss [ 73%]
sss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss [ 86%]
............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss............ssssssssssss... [ 99%]
.........                                                                                                                                           [100%]

==================================================================== warnings summary =====================================================================
../../../vllm/__init__.py:5
  /dockerx/vllm/vllm/__init__.py:5: RuntimeWarning: Failed to read commit hash:
  No module named 'vllm._version'
    from .version import __version__, __version_tuple__  # isort:skip

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================= 864 passed, 288 skipped, 1 warning in 349.31s (0:05:49) =================================================

(2) After this change, I was able to run a llama4 model on V1 (by bypassing the fall-back).


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@hongxiayang hongxiayang changed the title [Bugfix] fix the power of 2 exception when running llama4 model in tr… [Bugfix] fix the power of 2 exception when running llama4 model in triton_unified_attention.py May 13, 2025
@hongxiayang hongxiayang changed the title [Bugfix] fix the power of 2 exception when running llama4 model in triton_unified_attention.py [Bugfix] fix the power of 2 exception from triton_unified_attention.py when running llama4 models May 13, 2025
@ProExpertProg (Collaborator)

Comment on lines 97 to 99
# avoid power of 2 issue and pad it
BLOCK_Q_NUM_QUERY_PER_KV_PADDED: tl.constexpr = triton.next_power_of_2(
    BLOCK_Q * num_queries_per_kv)

Wouldn't we also need to introduce masking somewhere to account for this?
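
For context, a rough NumPy analogy of the masking being discussed (the names below are illustrative, not the kernel's actual variables):

import numpy as np

BLOCK_Q, num_queries_per_kv = 3, 5                    # values from the debug output above
real_extent = BLOCK_Q * num_queries_per_kv            # 15
padded_extent = 1 << (real_extent - 1).bit_length()   # 16, the next power of 2

offs = np.arange(padded_extent)   # analogous to tl.arange(0, padded_extent)
mask = offs < real_extent         # the padded lane(s) must be masked out of loads/stores
print(padded_extent, int(mask.sum()))  # 16 15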

@hongxiayang (Collaborator, Author):
will check on this.

@hongxiayang (Collaborator, Author):

Hi @tdoublep, @tjtanaa has added the masking logic. Is this good from your perspective?

@hongxiayang (Collaborator, Author):
@tdoublep I also noticed that if we pass BLOCK_M, which is currently hard-coded to 16, as the kernel parameter, we can prevent this power-of-2 issue and also simplify the code without needing the padding.

BLOCK_M = 16
BLOCK_Q = BLOCK_M // num_queries_per_kv

So, the BLOCK_Q * num_queries_per_kv essentially = BLOCK_M

Please let me know.

(Though I'm not sure why BLOCK_M is hard-coded to 16 rather than some other value.)


mergify bot commented May 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hongxiayang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

hongxiayang and others added 6 commits May 22, 2025 20:55
…iton_unified_attention.py

Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
@hongxiayang hongxiayang added the ready ONLY add when PR is ready to merge/full CI is needed label May 22, 2025
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
@hongxiayang hongxiayang changed the title [Bugfix] fix the power of 2 exception from triton_unified_attention.py when running llama4 models [Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix May 28, 2025
@tdoublep (Member)

@hongxiayang This does not seem to explicitly disable the fallback - is that intentional?

@hongxiayang (Collaborator, Author)

@hongxiayang This does not seem to explicitly disable the fallback - is that intentional?

@tdoublep Yes, it is intentional not to change the fall-back logic. I remember Greg has a PR that adds an environment variable to conditionally enable/disable the fall-back.

@tdoublep (Member) left a comment:

The changes look fine to me: it simplifies the kernel nicely. Thanks for catching some of the mistakes in the comments too. I'm still a little bit unsure how the actual performance will be for this model where it seems like BLOCK_M=15. Did you do any benchmarking at all? This doesn't need to block merging this, since the changes can't make anything worse.

I wasn't aware that the test was failing on MI300x actually. Are we running that as part of CI?

@hongxiayang (Collaborator, Author)

The changes look fine to me: it simplifies the kernel nicely. Thanks for catching some of the mistakes in the comments too. I'm still a little bit unsure how the actual performance will be for this model where it seems like BLOCK_M=15. Did you do any benchmarking at all? This doesn't need to block merging this, since the changes can't make anything worse.

I was thinking about benchmarking how this refactoring will improve performance end to end. Will do that later.

I wasn't aware that the test was failing on MI300x actually. Are we running that as part of CI?

Previously, this test in the AMD CI job "AMD MI300: Kernels Attention Test %N" had soft failures, as shown below:

[screenshot: soft-failing AMD MI300 Kernels Attention Test runs]

Now it passes.

You can also run pytest locally.

@hongxiayang (Collaborator, Author)

cc @DarkLight1337, are you comfortable approving/merging this one? Thanks.

@houseroad (Collaborator) left a comment:
Looks good.

@houseroad houseroad merged commit 269d901 into vllm-project:main May 28, 2025
66 checks passed
Comment on lines +16 to +18
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
    None, torch.float8_e4m3fnuz
]
A Collaborator commented:

This is incorrect; it should use current_platform.fp8_dtype():

QDTYPES = [None, current_platform.fp8_dtype()]
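
For comparison, a sketch of the two forms side by side (assuming the test uses the usual from vllm.platforms import current_platform import, and fp8_dtype() behaving as described in this comment):

import torch
from vllm.platforms import current_platform

# form currently in this PR: branch explicitly on ROCm
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
    None, torch.float8_e4m3fnuz
]

# suggested simplification: let the platform abstraction pick the fp8 dtype
QDTYPES = [None, current_platform.fp8_dtype()]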

amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
…tion.py when running llama4 models and unit test fix (vllm-project#18100)

Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: amit <amit.man@gmail.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…tion.py when running llama4 models and unit test fix (vllm-project#18100)

Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)

6 participants