
Commit 85ac82d

[Kernel] Make rotary_embedding ops more flexible with input shape (vllm-project#12777)
1 parent 1e57b1e commit 85ac82d

4 files changed: +115 -57 lines
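The change extends the `rotary_embedding` and `batched_rotary_embedding` custom ops so that query and key may be passed either with a flattened hidden dimension ([num_tokens, num_heads * head_size]) or with an explicit head dimension ([num_tokens, num_heads, head_size]); callers such as the MLA and DeepSeek-V2 paths no longer need to reshape to 2-D around the rope call. A minimal sketch (mine, not part of the commit; plain PyTorch, no vLLM op is invoked) of why the two layouts are interchangeable for the kernel:

import torch

num_tokens, num_heads, head_size = 8, 4, 64

# Flattened layout, the only one accepted before this change.
q_flat = torch.randn(num_tokens, num_heads * head_size)

# Explicit-head layout, now accepted directly by the ops.
q_heads = q_flat.view(num_tokens, num_heads, head_size)

# Both are views of the same contiguous storage, so the kernel can walk them
# with the same token-dimension stride (query.stride(seq_dim_idx) in the diff).
assert q_heads.data_ptr() == q_flat.data_ptr()
assert q_heads.stride(0) == q_flat.stride(0) == num_heads * head_size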

csrc/pos_encoding_kernels.cu (+89 -14)
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
 void rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
+  // num_tokens = batch_size * seq_len
+  int64_t num_tokens = positions.numel();
+  int positions_ndim = positions.dim();
+
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  // hidden_size = num_heads * head_size
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
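For reference, here is a rough Python equivalent (my own sketch, not code from the commit) of the validation the rewritten `rotary_embedding` entry point now performs before launching the kernel: num_tokens is taken from `positions` rather than from the query's last dimension, and the head counts are derived from the per-token hidden size.

import torch

def check_rope_inputs(positions: torch.Tensor, query: torch.Tensor,
                      key: torch.Tensor, head_size: int) -> None:
    # positions must be [num_tokens] or [batch_size, seq_len].
    assert positions.dim() in (1, 2)
    # The leading dims of query/key must match positions exactly.
    assert query.shape[:positions.dim()] == positions.shape
    assert key.shape[:positions.dim()] == positions.shape

    num_tokens = positions.numel()              # batch_size * seq_len
    query_hidden = query.numel() // num_tokens  # num_heads * head_size
    key_hidden = key.numel() // num_tokens
    assert query_hidden % head_size == 0 and key_hidden % head_size == 0

    # GQA/MQA: the query head count must be a multiple of the KV head count.
    assert (query_hidden // head_size) % (key_hidden // head_size) == 0

# Example: [batch_size, seq_len, num_heads, head_size] inputs with 8 query
# heads and 2 KV heads pass the checks.
check_rope_inputs(torch.zeros(2, 3, dtype=torch.long),
                  torch.randn(2, 3, 8, 64), torch.randn(2, 3, 2, 64), 64)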
@@ -165,19 +201,58 @@ and process in batched manner.
 void batched_rotary_embedding(
     torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
     torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
-                           // [num_tokens, num_heads * head_size]
+                           // [num_tokens, num_heads * head_size] or
+                           // [batch_size, seq_len, num_heads, head_size] or
+                           // [num_tokens, num_heads, head_size]
     torch::Tensor& key,  // [batch_size, seq_len, num_kv_heads * head_size] or
-                         // [num_tokens, num_kv_heads * head_size]
+                         // [num_tokens, num_kv_heads * head_size] or
+                         // [batch_size, seq_len, num_heads, head_size] or
+                         // [num_tokens, num_heads, head_size]
     int64_t head_size,
     torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
     bool is_neox, int64_t rot_dim,
-    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens] or [batch_size]
 ) {
+  // num_tokens = batch_size * seq_len
   int64_t num_tokens = cos_sin_cache_offsets.size(0);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
+  TORCH_CHECK(
+      positions.size(0) == num_tokens || positions.numel() == num_tokens,
+      "positions must have the same num_tokens or batch_size as "
+      "cos_sin_cache_offsets");
+
+  int positions_ndim = positions.dim();
+  // Make sure num_tokens dim is consistent across positions, query, and key.
+  TORCH_CHECK(
+      positions_ndim == 1 || positions_ndim == 2,
+      "positions must have shape [num_tokens] or [batch_size, seq_len]");
+  if (positions_ndim == 1) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
+        "query, key and positions must have the same number of tokens");
+  }
+  if (positions_ndim == 2) {
+    TORCH_CHECK(
+        query.size(0) == positions.size(0) &&
+            key.size(0) == positions.size(0) &&
+            query.size(1) == positions.size(1) &&
+            key.size(1) == positions.size(1),
+        "query, key and positions must have the same batch_size and seq_len");
+  }
+
+  // Make sure head_size is valid for query and key
+  int query_hidden_size = query.numel() / num_tokens;
+  int key_hidden_size = key.numel() / num_tokens;
+  TORCH_CHECK(query_hidden_size % head_size == 0);
+  TORCH_CHECK(key_hidden_size % head_size == 0);
+
+  // Make sure query and key have consistent number of heads
+  int num_heads = query_hidden_size / head_size;
+  int num_kv_heads = key_hidden_size / head_size;
+  TORCH_CHECK(num_heads % num_kv_heads == 0);
+
+  int seq_dim_idx = positions_ndim - 1;
+  int64_t query_stride = query.stride(seq_dim_idx);
+  int64_t key_stride = key.stride(seq_dim_idx);
 
   dim3 grid(num_tokens);
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
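The batched variant still derives num_tokens from `cos_sin_cache_offsets`, which may now be given per token or per batch, and cross-checks it against `positions`. A small sketch (illustrative shapes only, written by me rather than taken from the commit) of the combinations the new check admits:

import torch

batch_size, seq_len = 2, 3
positions = torch.zeros(batch_size, seq_len, dtype=torch.long)

per_token_offsets = torch.zeros(batch_size * seq_len, dtype=torch.long)  # [num_tokens]
per_batch_offsets = torch.zeros(batch_size, dtype=torch.long)            # [batch_size]

# Mirrors the check: positions.size(0) == n or positions.numel() == n,
# where n = cos_sin_cache_offsets.size(0).
for offsets in (per_token_offsets, per_batch_offsets):
    n = offsets.size(0)
    assert positions.size(0) == n or positions.numel() == n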

tests/kernels/test_pos_encoding.py (+22 -9)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from itertools import accumulate, product
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import pytest
 import torch
@@ -24,7 +24,21 @@
 ]
 
 
+def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads * head_size)
+
+
+def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                            head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size)
+
+
+TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+
+
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -36,6 +50,7 @@
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
 @torch.inference_mode()
 def test_batched_rotary_embedding(
     is_neox_style: bool,
+    tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
     batch_size: int,
     seq_len: int,
     num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
     rope = rope.to(dtype=dtype)
 
     positions = torch.randint(0, max_position, (batch_size, seq_len))
-    query = torch.randn(batch_size,
-                        seq_len,
-                        num_heads * head_size,
-                        dtype=dtype)
+    query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
+    query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query)
 
     # NOTE(woosuk): The reference implementation should be executed first
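To make the added shape coverage concrete, here is a short standalone illustration (not the test file itself) of what the two parametrized helpers from the diff feed into `torch.randn`; the same test body then exercises both layouts against the reference implementation:

import torch

def _get_flat_tensor_shape(batch_size, seq_len, num_heads, head_size):
    return (batch_size, seq_len, num_heads * head_size)

def _get_batch_tensor_shape(batch_size, seq_len, num_heads, head_size):
    return (batch_size, seq_len, num_heads, head_size)

for shape_fn in (_get_batch_tensor_shape, _get_flat_tensor_shape):
    query = torch.randn(shape_fn(2, 5, 8, 64))
    key = torch.randn_like(query)
    print(shape_fn.__name__, tuple(query.shape))
    # _get_batch_tensor_shape (2, 5, 8, 64)
    # _get_flat_tensor_shape  (2, 5, 512)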

vllm/attention/backends/mla/utils.py (+3 -22)
@@ -424,24 +424,6 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def apply_pure_rope(
-        self,
-        input_positions: torch.Tensor,
-        q_pe: torch.Tensor,
-        k_pe: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        seq_len = input_positions.size(0)
-        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-
-        q_pe, k_pe = self.rotary_emb(
-            input_positions,
-            q_pe.reshape(seq_len, -1),
-            k_pe.reshape(seq_len, -1),
-        )
-        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
-        return q_pe, k_pe
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -466,22 +448,21 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
-        rope_fn = (self.rotary_emb
-                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                rope_fn(
+                self.rotary_emb(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
vllm/model_executor/models/deepseek_v2.py (+1 -12)
@@ -257,9 +257,7 @@ def __init__(
                                         prefix=f"{prefix}.o_proj")
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
-            self.use_normal_rope = False
-        else:
-            self.use_normal_rope = True
+
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
@@ -309,17 +307,8 @@ def forward(
         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = latent_cache[:, :, self.kv_lora_rank:]
 
-        if self.use_normal_rope:
-            seq_len = positions.size(0)
-            ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
-            q_pe = q_pe.reshape(seq_len, -1)
-            k_pe = k_pe.reshape(seq_len, -1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.use_normal_rope:
-            q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
-
         q[..., self.qk_nope_head_dim:] = q_pe
         k = torch.empty_like(q)
         k[..., :self.qk_nope_head_dim] = k_nope
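With the ops accepting an explicit head dimension, both call sites lose their reshape workarounds: the MLA backend drops `apply_pure_rope` and the `rope_fn` selection, and DeepSeek-V2 drops the `use_normal_rope` flag and the reshape/view dance around the rope call. Condensed before/after, taken from the removed and added lines above:

# Before: flatten q_pe/k_pe to 2-D around the rope call, then restore shapes.
# q_pe, k_pe = self.rotary_emb(positions,
#                              q_pe.reshape(seq_len, -1),
#                              k_pe.reshape(seq_len, -1))
# q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
#
# After: pass [num_tokens, num_heads, head_size] tensors directly.
# q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)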
