Skip to content

Commit fded91d

Browse files
CjhHa1FrankLeeeee
authored andcommitted
[Inference] Kernel: no pad rotary embedding (#5252)
* fix bugs * comment * use more accurate atol * fix
1 parent d40eb26 commit fded91d

File tree

3 files changed

+207
-0
lines changed

3 files changed

+207
-0
lines changed

colossalai/kernel/triton/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from .context_attn_unpad import context_attention_unpadded
1212
from .fused_layernorm import layer_norm
1313
from .gptq_triton import gptq_fused_linear_triton
14+
from .no_pad_rotary_embedding import rotary_embedding
1415
from .softmax import softmax
1516

1617
__all__ = [
1718
"context_attention_unpadded",
1819
"softmax",
1920
"layer_norm",
2021
"gptq_fused_linear_triton",
22+
"rotary_embedding",
2123
]
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
6+
@triton.jit
7+
def rotary_embedding_kernel(
8+
q,
9+
k,
10+
cos,
11+
sin,
12+
q_token_stride,
13+
q_head_stride,
14+
k_token_stride,
15+
k_head_stride,
16+
head_dim_stride,
17+
cos_token_stride,
18+
cos_stride,
19+
q_total_tokens,
20+
Q_HEAD_NUM: tl.constexpr,
21+
K_HEAD_NUM: tl.constexpr,
22+
HEAD_DIM: tl.constexpr,
23+
BLOCK_HEAD: tl.constexpr,
24+
BLOCK_TOKENS: tl.constexpr,
25+
):
26+
block_head_index = tl.program_id(0)
27+
block_token_index = tl.program_id(1)
28+
29+
rotary_data = q
30+
HEAD_NUM = Q_HEAD_NUM
31+
head_stride = q_head_stride
32+
token_stride = q_token_stride
33+
34+
if block_token_index * BLOCK_TOKENS >= q_total_tokens:
35+
block_token_index = block_token_index - tl.cdiv(q_total_tokens, BLOCK_TOKENS)
36+
rotary_data = k
37+
HEAD_NUM = K_HEAD_NUM
38+
head_stride = k_head_stride
39+
token_stride = k_token_stride
40+
41+
tokens_range = block_token_index * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
42+
head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
43+
44+
dim_range0 = tl.arange(0, HEAD_DIM // 2)
45+
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
46+
47+
off_data0 = (
48+
tokens_range[:, None, None] * token_stride
49+
+ head_range[None, :, None] * head_stride
50+
+ dim_range0[None, None, :] * head_dim_stride
51+
)
52+
off_data1 = (
53+
tokens_range[:, None, None] * token_stride
54+
+ head_range[None, :, None] * head_stride
55+
+ dim_range1[None, None, :] * head_dim_stride
56+
)
57+
58+
loaded_data0 = tl.load(
59+
rotary_data + off_data0,
60+
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
61+
other=0.0,
62+
)
63+
loaded_data1 = tl.load(
64+
rotary_data + off_data1,
65+
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
66+
other=0.0,
67+
)
68+
69+
off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride
70+
71+
loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
72+
loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)
73+
74+
out0 = loaded_data0 * loaded_cos[:, None, :] - loaded_data1 * loaded_sin[:, None, :]
75+
out1 = loaded_data0 * loaded_sin[:, None, :] + loaded_data1 * loaded_cos[:, None, :]
76+
77+
# concat
78+
tl.store(
79+
rotary_data + off_data0,
80+
out0,
81+
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
82+
)
83+
tl.store(
84+
rotary_data + off_data1,
85+
out1,
86+
mask=((head_range[None, :, None] < HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),
87+
)
88+
89+
90+
@torch.no_grad()
91+
def rotary_embedding(
92+
q: torch.Tensor,
93+
k: torch.Tensor,
94+
cos: torch.Tensor,
95+
sin: torch.Tensor,
96+
):
97+
"""
98+
Args:
99+
q: query tensor, [total_tokens, head_num, head_dim]
100+
k: key tensor, [total_tokens, head_num, head_dim]
101+
cos: cosine for rotary embedding, [total_tokens, head_dim]
102+
sin: sine for rotary embedding, [total_tokens, head_dim]
103+
"""
104+
q_total_tokens, q_head_num, head_dim = q.shape
105+
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
106+
BLOCK_HEAD = 4
107+
BLOCK_TOKENS = 8
108+
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
109+
110+
if head_dim >= 128:
111+
num_warps = 8
112+
else:
113+
num_warps = 4
114+
115+
q_token_stride = q.stride(0)
116+
q_head_stride = q.stride(1)
117+
head_dim_stride = q.stride(2)
118+
119+
k_token_stride = k.stride(0)
120+
k_head_stride = k.stride(1)
121+
122+
k_head_num = q.shape[1]
123+
124+
cos_token_stride = cos.stride(0)
125+
cos_stride = cos.stride(1)
126+
127+
rotary_embedding_kernel[grid](
128+
q,
129+
k,
130+
cos,
131+
sin,
132+
q_token_stride,
133+
q_head_stride,
134+
k_token_stride,
135+
k_head_stride,
136+
head_dim_stride,
137+
cos_token_stride,
138+
cos_stride,
139+
q_total_tokens,
140+
Q_HEAD_NUM=q_head_num,
141+
K_HEAD_NUM=k_head_num,
142+
HEAD_DIM=head_dim,
143+
BLOCK_HEAD=BLOCK_HEAD,
144+
BLOCK_TOKENS=BLOCK_TOKENS,
145+
num_warps=num_warps,
146+
num_stages=1,
147+
)
148+
149+
return
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
import torch
3+
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
4+
5+
from colossalai.kernel.triton import rotary_embedding
6+
7+
8+
def torch_rotary_emb(x, cos, sin):
9+
seq_len, h, dim = x.shape
10+
x0 = x[:, :, 0 : dim // 2]
11+
x1 = x[:, :, dim // 2 : dim]
12+
cos = cos.view((seq_len, 1, dim // 2))
13+
sin = sin.view((seq_len, 1, dim // 2))
14+
o0 = x0 * cos - x1 * sin
15+
o1 = x0 * sin + x1 * cos
16+
return torch.cat((o0, o1), dim=-1)
17+
18+
19+
@pytest.mark.parametrize("BATCH_SIZE", [4])
20+
@pytest.mark.parametrize("SEQ_LEN", [64])
21+
@pytest.mark.parametrize("H", [32])
22+
@pytest.mark.parametrize("D", [64])
23+
@pytest.mark.parametrize("dtype", [torch.float32])
24+
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
25+
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
26+
# our crafted op equals to Transformers
27+
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
28+
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
29+
emb = LlamaRotaryEmbedding(D)
30+
cos, sin = emb(x0, TOTAL_TOKENS)
31+
cos_2 = cos[:, :32]
32+
sin_2 = sin[:, :32]
33+
position_ids = torch.arange(TOTAL_TOKENS)
34+
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
35+
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
36+
assert torch.allclose(embd_x0, embd_stimulated_x)
37+
38+
# create data
39+
q_shape = (TOTAL_TOKENS, H, D)
40+
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
41+
k_shape = (TOTAL_TOKENS, H, D)
42+
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
43+
cos_shape = (TOTAL_TOKENS, D // 2)
44+
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
45+
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
46+
47+
q_ref = torch_rotary_emb(q, cos, sin)
48+
k_ref = torch_rotary_emb(k, cos, sin)
49+
rotary_embedding(q, k, cos, sin)
50+
51+
assert torch.allclose(q, q_ref, atol=1e-4, rtol=1e-4)
52+
assert torch.allclose(k, k_ref, atol=1e-4, rtol=1e-4)
53+
54+
55+
if __name__ == "__main__":
56+
test_rotary_emb(4, 64, 32, 64, torch.float32)

0 commit comments

Comments
 (0)