
Commit d0969d0

Isotr0py authored and xuebwang-amd committed

[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (vllm-project#25055)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

1 parent 356b612 commit d0969d0

File tree

2 files changed (+88, -46 lines)

tests/kernels/core/test_mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py


tests/kernels/core/test_mrope.py

Lines changed: 66 additions & 32 deletions
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import NamedTuple
 
 import pytest
 import torch
+from packaging.version import Version
 from transformers import AutoConfig
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.platforms import current_platform
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                        head_size: int, max_position_embeddings: int,
                        dtype: torch.dtype, device: torch.device):
     """Generate test data for given configuration."""
+    current_platform.seed_everything(42)
     # Create 2D positions (3, num_tokens) for multimodal case
     positions = torch.randint(0,
                               max_position_embeddings // 4, (3, num_tokens),
@@ -33,43 +37,67 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
     return positions, query, key
 
 
-def unroll_model_tp_dict(model_tp_dict):
-    return [(model_name, tp_size)
-            for model_name, tp_sizes in model_tp_dict.items()
-            for tp_size in tp_sizes]
-
-
-model_tp_dict = {
-    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
-    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
-}
-
-# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
-dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-2, 1.6e-2],
+class MRoPETestInfo(NamedTuple):
+    model_name: str
+    # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+    atol: float = 1e-2
+    rtol: float = 1.6e-2
+    marks: list[pytest.MarkDecorator] = []
+
+
+TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
+
+MODELS_TO_TEST = [
+    MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-4B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
 ]
 
 num_tokens_list = [11, 8192]
 
 
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize("model_name, tp_size",
-                         unroll_model_tp_dict(model_tp_dict))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
-def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
+               dtype: torch.dtype, num_tokens: int):
+
+    atol = model_info.atol
+    rtol = model_info.rtol
 
     config = AutoConfig.from_pretrained(model_name)
+    config = config.get_text_config()
 
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
 
     rope_theta = config.rope_theta
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
 
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize(
-    "model_name, tp_size",
-    unroll_model_tp_dict({
-        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
-    }))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
-@pytest.mark.parametrize("num_tokens", [4])
-def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
-                                     num_tokens):
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope_torch_compile_tracing(model_name: str,
+                                     model_info: MRoPETestInfo, tp_size: int,
+                                     dtype: torch.dtype, num_tokens: int):
+
+    atol = model_info.atol
+    rtol = model_info.rtol
+
     config = AutoConfig.from_pretrained(model_name)
+    config = config.get_text_config()
 
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
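
Note: the rewritten parametrization relies on pytest.param carrying per-case marks, so the version-gated Qwen3-VL entries are skipped individually instead of skipping the whole test matrix. A minimal standalone sketch of that pattern (the case names below are illustrative, not from the commit):

import pytest

_CASES = [
    ("model-a", []),
    ("model-b", [pytest.mark.skip(reason="illustrative version gate")]),
]


@pytest.mark.parametrize(
    "name", [pytest.param(name, marks=marks) for name, marks in _CASES])
def test_case_is_selected(name: str):
    # "model-b" is skipped by its own mark; "model-a" still runs.
    assert isinstance(name, str)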

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 22 additions & 14 deletions
@@ -15,7 +15,7 @@
 
 
 @triton.jit
-def _triton_qwen2vl_mrope_forward(
+def _triton_mrope_forward(
     q_ptr,
     k_ptr,
     cos,
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
     pad_hd: tl.constexpr,
     mrope_section_t: tl.constexpr,
     mrope_section_h: tl.constexpr,
+    mrope_section_w: tl.constexpr,
+    is_interleaved: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
     # This version supports flatten input tensors from vllm
     # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
-    # instead of (3, bsz, seq_len, head_dim)
+    # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
     pid = tl.program_id(0)
     # locate start address
     q_ptr = q_ptr + pid * (n_qh * hd)
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
     # ####################################################################
     # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
 
-    t_end = mrope_section_t
-    h_end = t_end + mrope_section_h
-
     # Updated stride calculation for half head_dim
     half_rd = rd // 2
     t_cos = cos + pid * half_rd
@@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward(
 
     # Updated offsets for half head_dim
     cos_offsets = tl.arange(0, pad_hd // 2)
-    t_mask = cos_offsets < t_end
-    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
-    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
+    if is_interleaved:
+        h_mask = (((cos_offsets % 3) == 1) &
+                  (cos_offsets <= 3 * mrope_section_h))
+        w_mask = (((cos_offsets % 3) == 2) &
+                  (cos_offsets <= 3 * mrope_section_w))
+        t_mask = ~(h_mask | w_mask)
+    else:
+        t_end = mrope_section_t
+        h_end = t_end + mrope_section_h
+        t_mask = cos_offsets < mrope_section_t
+        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
 
     t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
     h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -131,6 +139,7 @@ def triton_mrope(
     mrope_section: list[int],
     head_size: int,
     rotary_dim: int,
+    mrope_interleaved: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Qwen2VL mrope kernel.
 
@@ -158,7 +167,7 @@ def triton_mrope(
     cos = cos.contiguous()
     sin = sin.contiguous()
 
-    _triton_qwen2vl_mrope_forward[(n_row, )](
+    _triton_mrope_forward[(n_row, )](
         q,
         k,
         cos,
@@ -173,6 +182,8 @@ def triton_mrope(
         pad_hd,
         mrope_section[0],
         mrope_section[1],
+        mrope_section[2],
+        mrope_interleaved,
     )
     return q, k
 
@@ -201,7 +212,7 @@ def __init__(
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
-        mrope_interleaved: Optional[bool] = False,
+        mrope_interleaved: bool = False,
     ) -> None:
         # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times to get
@@ -282,10 +293,6 @@ def forward_cuda(
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
-        if self.mrope_interleaved:
-            # TODO: add triton implementation to support mrope-interleaved
-            return self.forward_native(positions, query, key)
-
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -302,6 +309,7 @@ def forward_cuda(
             self.mrope_section,
             self.head_size,
             self.rotary_dim,
+            self.mrope_interleaved,
         )
 
         return q.reshape(query_shape), k.reshape(key_shape)
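
Note: the mask change above is the heart of the commit. For Qwen2-VL/Qwen2.5-VL the three mrope_section blocks (temporal, height, width) occupy contiguous channel ranges; for Qwen3-VL the t/h/w streams alternate across channels, with the tail falling back to temporal, so forward_cuda no longer needs to bail out to forward_native. A minimal PyTorch sketch of the same mask logic outside Triton (section_masks is a hypothetical helper written for illustration, not part of the commit):

import torch


def section_masks(half_rd: int, section_t: int, section_h: int,
                  section_w: int, interleaved: bool):
    # Boolean masks choosing the t/h/w position stream for each of the
    # half_rd rotary channels, mirroring _triton_mrope_forward above.
    idx = torch.arange(half_rd)
    if interleaved:
        # Qwen3-VL: h on idx % 3 == 1, w on idx % 3 == 2 (within their
        # section budgets); every remaining channel, incl. the tail, is t.
        h = (idx % 3 == 1) & (idx <= 3 * section_h)
        w = (idx % 3 == 2) & (idx <= 3 * section_w)
        t = ~(h | w)
    else:
        # Qwen2-VL / Qwen2.5-VL: contiguous [t | h | w] blocks.
        t = idx < section_t
        h = (section_t <= idx) & (idx < section_t + section_h)
        w = (section_t + section_h <= idx) & (idx < half_rd)
    return t, h, w


# Illustrative sections only: mrope_section = [t, h, w] summing to half_rd.
t, h, w = section_masks(12, 4, 4, 4, interleaved=True)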
