@lgeiger
Contributor

@lgeiger lgeiger commented Oct 11, 2025

Purpose

Follow-up to #25337 and #25347. fast_pos_embed_interpolate launches many small CUDA ops, so this PR slightly simplifies and optimises the implementation.
/cc @ywang96 @Isotr0py

Test Plan & Results

I verified that the new implementation doesn't change the computation.

A quick microbenchmark on an H100 shows fast_pos_embed_interpolate taking about 15% less time per call (1.33 ms on main vs. 1.13 ms with this PR), and I don't think the change reduces the readability of the code.

import torch
import torch.nn as nn


class Qwen3_VisionTransformer(nn.Module):
    def __init__(self, hidden_size, num_position_embeddings, spatial_merge_size):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.num_grid_per_side = int(num_position_embeddings**0.5)
        self.dtype = torch.bfloat16
        self.device = torch.device("cuda:0")
        self.pos_embed = nn.Embedding(
            num_position_embeddings, hidden_size, dtype=self.dtype, device=self.device
        )

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        ...

    def bench_fast_pos_embed_interpolate(self, grid_thw: list[list[int]]):
        self.fast_pos_embed_interpolate(grid_thw)
        torch.cuda.synchronize()


model = Qwen3_VisionTransformer(
    hidden_size=1152, num_position_embeddings=2304, spatial_merge_size=2
)

grid_thw = [[1, 64, 48], [1, 64, 48], [1, 64, 48], [1, 64, 48], [1, 64, 48]]

model.bench_fast_pos_embed_interpolate(grid_thw)
%timeit model.bench_fast_pos_embed_interpolate(grid_thw)
main:    1.33 ms ± 30.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
This PR: 1.13 ms ± 3.85 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
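
As a quick sanity check on the numbers above, the two timings can be turned into a relative time reduction and a speedup factor. This is only arithmetic on the reported means; the variable names are made up for illustration.

```python
# Mean per-call timings reported by %timeit above, in milliseconds.
main_ms = 1.33
pr_ms = 1.13

# Fraction of the original runtime that was saved.
time_reduction = (main_ms - pr_ms) / main_ms

# How many times faster the new implementation is.
speedup_factor = main_ms / pr_ms

print(f"time reduction: {time_reduction:.1%}")
print(f"speedup factor: {speedup_factor:.2f}x")
```

The "15%" figure corresponds to the reduction in runtime; expressed as a throughput ratio, the same measurement is roughly a 1.18x speedup.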

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
@lgeiger lgeiger requested a review from sighingnow as a code owner October 11, 2025 20:29
@mergify mergify bot added the qwen Related to Qwen models label Oct 11, 2025
Comment on lines -492 to +489
- weights = weights.to(
-     dtype=self.dtype, device=self.device, non_blocking=True
- )
+ weights = weights.to(dtype=self.dtype)
Contributor Author

weights is already on self.device, so there is no need to transfer it again.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces several well-reasoned optimizations to the fast_pos_embed_interpolate function in Qwen3VL. The changes, including refactoring weight calculations, vectorizing index computations, and using more efficient PyTorch operations like .sum() instead of unbind() followed by manual addition, are mathematically sound and contribute to the reported 11% performance improvement. The code is now more concise and idiomatic. I've reviewed the changes and found them to be correct and beneficial for performance without sacrificing readability. This is a solid improvement.
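
To illustrate the `.sum()` versus `unbind()` point from the review, here is a minimal sketch that checks the two formulations agree. The shapes and names (`corner_embeds`, `corner_weights`) are hypothetical and are not taken from the vLLM code; the sketch runs on CPU.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes: 4 stacked terms, 6 output positions, hidden size 8.
corner_embeds = torch.randn(4, 6, 8)
corner_weights = torch.rand(4, 6, 1)

# Weight each term; broadcasting expands the trailing dimension.
weighted = corner_embeds * corner_weights

# Variant 1: split along the first axis and add the pieces manually.
p0, p1, p2, p3 = weighted.unbind(dim=0)
combined_unbind = p0 + p1 + p2 + p3

# Variant 2: a single reduction over the first axis.
combined_sum = weighted.sum(dim=0)

# The results should match up to floating-point summation order.
same = torch.allclose(combined_unbind, combined_sum, atol=1e-6)
print(same, tuple(combined_sum.shape))
```

The single reduction replaces several small elementwise additions with one op, which is the kind of change that can reduce the number of kernel launches on a GPU.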

@lgeiger
Contributor Author

lgeiger commented Oct 11, 2025

The grid_thw in the benchmark above came from an actual model request. I am now wondering whether it would make sense to cache the result for each (t, h, w), which in the above case of 5 images with the same shape would significantly speed up this code. I guess this could cause slightly higher memory usage, but we could either put the cache on the CPU or, at the very least, handle the case where all items in the grid_thw list are the same.
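
A rough sketch of what such a cache could look like, assuming the per-image result depends only on (t, h, w). `GridCache` and `fake_interpolate` are hypothetical names used for illustration; `fake_interpolate` is a stand-in and does not perform the real interpolation.

```python
from typing import Callable


class GridCache:
    """Memoize a per-image computation keyed by its (t, h, w) grid."""

    def __init__(self, compute: Callable[[int, int, int], object]):
        self._compute = compute
        self._cache: dict[tuple[int, int, int], object] = {}
        self.misses = 0  # number of times the real computation ran

    def get(self, t: int, h: int, w: int) -> object:
        key = (t, h, w)
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._compute(t, h, w)
        return self._cache[key]


def fake_interpolate(t: int, h: int, w: int) -> int:
    # Stand-in for the real per-image work: returns the patch count.
    return t * h * w


cache = GridCache(fake_interpolate)
grid_thw = [[1, 64, 48]] * 5  # same grid as in the benchmark above
results = [cache.get(t, h, w) for t, h, w in grid_thw]
print(cache.misses, results[0])
```

With five identical grids the stand-in computation runs once. The dictionary here is unbounded, so a real version would probably need a size limit (for example via `functools.lru_cache`) to keep the memory overhead in check.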

Member

@ywang96 ywang96 left a comment

Thanks for the contribution!

Re: caching t, h, w: I had the same idea when I first cleaned up the code here, namely that we could build a small cache on the CPU for this. But I also wonder whether it would be worth the effort given the host-to-device (h2d) copy cost.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 12, 2025 09:42
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 12, 2025
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
auto-merge was automatically disabled October 12, 2025 11:50

Head branch was pushed to by a user without write access

@lgeiger
Contributor Author

lgeiger commented Oct 12, 2025

Thanks for the fast review. I also updated the tiling logic in 24b6717, which slightly improves performance further. I updated the benchmarks above accordingly.

@Isotr0py Isotr0py merged commit a6049be into vllm-project:main Oct 12, 2025
54 checks passed
@lgeiger lgeiger deleted the qwen3-pos-interpolate branch October 12, 2025 17:28
1994 pushed a commit to 1994/vllm that referenced this pull request Oct 14, 2025
…26647)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Signed-off-by: 1994 <1994@users.noreply.github.com>