
Conversation

@BoyuanFeng (Contributor) commented on Oct 13, 2025

Before this PR, the attention output is allocated and zero-initialized (via torch.zeros), then viewed into a shape, before the output tensor is used by any other ops. This zero-fill becomes a Triton kernel with ~1 us latency, which is on par with a RoPE/layer-norm kernel (~1.6 us).

This PR switches the allocation to torch.empty, which only allocates the tensor and does not initialize it. The allocation itself is removed by CUDA graph capture, so it is effectively free.

As a result, this PR removes the attn_out_view kernel at the end of a qwen3-0.6b trace.
[trace screenshot]

See #26682 (comment) for perf win.
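
A minimal sketch of the change described above (function and variable names here are illustrative, not the exact code in vllm/attention/layer.py): allocate the attention output with torch.empty instead of torch.zeros, so no fill kernel is launched and the allocation folds into the CUDA graph.

```python
import torch

def allocate_attn_output(num_tokens: int, num_heads: int, head_size: int,
                         dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Before: torch.zeros(...) launched a ~1 us fill kernel ahead of every attention call.
    # output = torch.zeros(num_tokens, num_heads * head_size, dtype=dtype, device=device)

    # After: torch.empty(...) only reserves memory; no initialization kernel is launched,
    # and under CUDA graph capture the allocation itself is effectively free.
    output = torch.empty(num_tokens, num_heads * head_size, dtype=dtype, device=device)
    return output.view(num_tokens, num_heads, head_size)
```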

mergify bot commented on Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @BoyuanFeng.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Oct 13, 2025
gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request optimizes attention output tensor allocation by using torch.empty instead of torch.zeros, avoiding an unnecessary kernel launch and improving performance. A new configuration init_attn_out is introduced to maintain backward compatibility for attention backends that require zero-initialized output tensors. The changes are logical and well-motivated. I have one suggestion to improve code clarity and maintainability in vllm/attention/layer.py by refactoring the output tensor shape calculation to avoid variable reuse, which can be error-prone.
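
A small sketch of the kind of refactor suggested here (names are hypothetical, not the actual code in vllm/attention/layer.py): keep the view shape in a dedicated variable instead of reusing and overwriting another one.

```python
import torch

def make_output(query: torch.Tensor, num_heads: int, head_size: int) -> torch.Tensor:
    # Compute the output shape in its own variable instead of reusing an
    # existing one, so the original tensor/shape is not silently shadowed.
    output_shape = (query.shape[0], num_heads * head_size)
    return torch.empty(output_shape, dtype=query.dtype, device=query.device)
```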

mergify bot removed the needs-rebase label on Oct 13, 2025
BoyuanFeng and others added 3 commits October 12, 2025 21:46
@ProExpertProg (Collaborator) commented:

cc @WoosukKwon, I think you added the zero-init of the attention output?

@WoosukKwon (Collaborator) commented:

@BoyuanFeng @ProExpertProg

I think it was my mistake that we zero-initialize the buffer on every forward pass. My intent was to do it only when the whole attention operation is skipped, as in the profiling run. I think we can move output.zero_() to here?

if attn_metadata is None:
    # Profiling run.
    return output

I think it's still important to keep NaNs out of these buffers, because some kernels could error out on them. I previously hit this issue with a custom MoE kernel whose top-k and routing kernel doesn't handle NaNs well.

mergify bot added the v1 label on Oct 14, 2025
@BoyuanFeng (Contributor, Author) commented:

@WoosukKwon thanks for the info! I added output.fill_(0) for profiling runs and use torch.empty for all other forward passes. I verified that this still removes the extra Triton kernel.
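
A hedged sketch of what this looks like in the attention forward path (control flow is simplified and the names, including the attn_impl callable, are illustrative rather than the real vLLM code): the output buffer is allocated uninitialized, and only the profiling run, which skips the attention kernel entirely, zero-fills it before returning.

```python
import torch

def attention_forward(query, key, value, attn_metadata, attn_impl):
    # Allocate the output without initializing it; no fill kernel is launched.
    output = torch.empty_like(query)

    if attn_metadata is None:
        # Profiling run: the attention kernel is skipped, so the buffer would
        # otherwise hold garbage (possibly NaNs). Zero it before returning.
        output.fill_(0)
        return output

    # Normal forward pass: the attention kernel overwrites every element of
    # `output`, so the uninitialized memory is never observed downstream.
    attn_impl(query, key, value, out=output)
    return output
```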

ProExpertProg added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 14, 2025
@ProExpertProg (Collaborator) left a comment


Can you also check the attention fusion pass? I think it might rely on the fact that we use zeros for the attention output.

zou3519 enabled auto-merge (squash) on October 14, 2025 at 18:50
zou3519 merged commit a86b4c5 into vllm-project:main on Oct 14, 2025 (49 checks passed)
Jonahcb pushed a commit to Jonahcb/vllm that referenced this pull request on Oct 15, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request on Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request on Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request on Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 24, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request on Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request on Oct 26, 2025
ProExpertProg mentioned this pull request on Oct 28, 2025
ProExpertProg linked an issue on Oct 28, 2025 that may be closed by this pull request

Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1

Successfully merging this pull request may close these issues: [Feature]: Optimize RoPE

5 participants