
Commit c647e00

[Inference] Add fused rotary kernel and get cos cache kernel (#5302)

* add fused rotary and get cos cache func
* staged
* fix bugs
* fix bugs
1 parent 3da9993 commit c647e00

File tree

6 files changed (+477, -5 lines)


colossalai/kernel/triton/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -11,11 +11,12 @@
 from .context_attn_unpad import context_attention_unpadded
 from .flash_decoding import flash_decoding_attention
 from .flash_decoding_utils import FDIntermTensors
-
-from .rms_layernorm import rms_layernorm
+from .fused_rotary_embedding import fused_rotary_embedding
 from .gptq_triton import gptq_fused_linear_triton
 from .kvcache_copy import copy_kv_to_blocked_cache
 from .no_pad_rotary_embedding import rotary_embedding
+from .rms_layernorm import rms_layernorm
+from .rotary_cache_copy import get_xine_cache
 from .softmax import softmax

 __all__ = [
@@ -27,4 +28,6 @@
     "gptq_fused_linear_triton",
     "rotary_embedding",
     "FDIntermTensors",
+    "fused_rotary_embedding",
+    "get_xine_cache",
 ]
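
With these re-exports in place, callers can pull the new kernels straight from the triton subpackage. A one-line sketch (not part of the commit):

from colossalai.kernel.triton import fused_rotary_embedding, get_xine_cache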
colossalai/kernel/triton/fused_rotary_embedding.py

Lines changed: 182 additions & 0 deletions
import torch
import triton
import triton.language as tl


@triton.jit
def fused_rotary_emb(
    q,
    k,
    cos_cache,
    sin_cache,
    cumsum_lengths,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_dim_stride,
    q_total_tokens,
    Q_HEAD_NUM: tl.constexpr,
    K_HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    N_ELEMENTS: tl.constexpr,
):
    block_head_index = tl.program_id(0)
    block_group_index = tl.program_id(1)
    group_token_index = tl.program_id(2)
    idx = block_group_index * BLOCK_SIZE + group_token_index

    # recover the original position of this token inside its own sequence
    cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
    ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
    cos = tl.load(
        cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride
    )  # [1, HEAD_DIM // 2]
    sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride)

    cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    off_q0 = (
        idx * q_token_stride
        + cur_head_range[None, :, None] * q_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_q1 = (
        idx * q_token_stride
        + cur_head_range[None, :, None] * q_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )

    off_k0 = (
        idx * k_token_stride
        + cur_head_range[None, :, None] * k_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_k1 = (
        idx * k_token_stride
        + cur_head_range[None, :, None] * k_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )

    q_0 = tl.load(
        q + off_q0,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    q_1 = tl.load(
        q + off_q1,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    k_0 = tl.load(
        k + off_k0,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    k_1 = tl.load(
        k + off_k1,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    # rotate the two halves: out0 = x0 * cos - x1 * sin, out1 = x0 * sin + x1 * cos
    out_q0 = q_0 * cos - q_1 * sin
    out_q1 = q_0 * sin + q_1 * cos

    out_k0 = k_0 * cos - k_1 * sin
    out_k1 = k_0 * sin + k_1 * cos

    # write the rotated halves back in place
    tl.store(
        q + off_q0,
        out_q0,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
    )
    tl.store(
        q + off_q1,
        out_q1,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
    )

    tl.store(
        k + off_k0,
        out_k0,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
    )
    tl.store(
        k + off_k1,
        out_k1,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
    )


@torch.no_grad()
def fused_rotary_embedding(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    lengths,
):
    """
    Args:
        q: query tensor, [total_tokens, head_num, head_dim]
        k: key tensor, [total_tokens, head_num, head_dim]
        cos: cosine for rotary embedding, [max_position_len, head_dim]
        sin: sine for rotary embedding, [max_position_len, head_dim]
        lengths: sequence lengths, [num_seqs]
    """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.size(0) == k.size(0)
    BLOCK_HEAD = 4
    BLOCK_SIZE = 16
    cumsum_lens = torch.cumsum(lengths, dim=0)

    grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)

    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    q_token_stride = q.stride(0)
    q_head_stride = q.stride(1)
    head_dim_stride = q.stride(2)

    k_token_stride = k.stride(0)
    k_head_stride = k.stride(1)

    k_head_num = k.shape[1]

    cos_token_stride = cos.stride(0)
    cos_dim_stride = cos.stride(1)

    fused_rotary_emb[grid](
        q,
        k,
        cos,
        sin,
        cumsum_lens,
        q_token_stride,
        q_head_stride,
        k_token_stride,
        k_head_stride,
        head_dim_stride,
        cos_token_stride,
        cos_dim_stride,
        q_total_tokens,
        Q_HEAD_NUM=q_head_num,
        K_HEAD_NUM=k_head_num,
        HEAD_DIM=head_dim,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SIZE=BLOCK_SIZE,
        N_ELEMENTS=triton.next_power_of_2(lengths.numel()),  # cumsum_lens holds one entry per sequence
        num_warps=num_warps,
    )
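
A minimal usage sketch for the wrapper above (not part of the commit): it assumes a CUDA device with Triton installed, builds a conventional RoPE-style cos/sin cache with the [max_position_len, head_dim] layout described in the docstring, and rotates q and k in place. The sizes, the 10000 base, and the variable names are illustrative assumptions.

import torch

from colossalai.kernel.triton import fused_rotary_embedding

head_num, head_dim, max_position_len = 8, 64, 2048
lengths = torch.tensor([4, 7], dtype=torch.int32, device="cuda")  # two unpadded sequences
total_tokens = int(lengths.sum())

q = torch.randn(total_tokens, head_num, head_dim, device="cuda")
k = torch.randn(total_tokens, head_num, head_dim, device="cuda")

# Conventional RoPE cache: angle(pos, i) = pos / 10000^(2i / head_dim), duplicated across both halves.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
angles = torch.outer(torch.arange(max_position_len, device="cuda").float(), inv_freq)
emb = torch.cat((angles, angles), dim=-1)  # [max_position_len, head_dim]
cos, sin = emb.cos(), emb.sin()

fused_rotary_embedding(q, k, cos, sin, lengths)  # q and k are updated in place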

colossalai/kernel/triton/no_pad_rotary_embedding.py

Lines changed: 4 additions & 3 deletions
@@ -98,11 +98,12 @@ def rotary_embedding(
     Args:
         q: query tensor, [total_tokens, head_num, head_dim]
         k: key tensor, [total_tokens, head_num, head_dim]
-        cos: cosine for rotary embedding, [total_tokens, head_dim]
-        sin: sine for rotary embedding, [total_tokens, head_dim]
+        cos: cosine for rotary embedding, [max_position_len, head_dim]
+        sin: sine for rotary embedding, [max_position_len, head_dim]
+        lengths [num_seqs]
     """
     q_total_tokens, q_head_num, head_dim = q.shape
-    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+    assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 8
     grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
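
For reference (not part of the commit), a plain-PyTorch sketch of the half-rotation that the fused Triton kernel earlier in this diff applies to q and k, assuming cos/sin have already been gathered per token. The function and variable names below are illustrative only.

import torch

def rotate_half_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x:   [total_tokens, head_num, head_dim]
    # cos: [total_tokens, head_dim // 2], one row per token
    # sin: [total_tokens, head_dim // 2]
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    cos = cos[:, None, :]  # broadcast over heads
    sin = sin[:, None, :]
    out0 = x0 * cos - x1 * sin
    out1 = x0 * sin + x1 * cos
    return torch.cat((out0, out1), dim=-1)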
colossalai/kernel/triton/rotary_cache_copy.py

Lines changed: 110 additions & 0 deletions
import torch
import triton
import triton.language as tl


@triton.jit
def prefill_cache_kernel(
    CaChe,
    cumsum_lengths,
    output,
    cache_stride,
    hidden_stride,
    total_length,
    HIDDEN_DIM: tl.constexpr,
    N_ELEMENTS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    idx0 = tl.program_id(axis=0)
    idx1 = tl.program_id(axis=1)
    idx = idx0 * BLOCK_SIZE + idx1

    # recover the original position of this token inside its own sequence
    cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
    ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
    _cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
    tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)


@triton.jit
def decoding_cache_kernel(
    CaChe,
    lengths,
    output,
    cache_stride,
    hidden_stride,
    HIDDEN_DIM: tl.constexpr,
    NUM_SEQS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None)  # [BLOCK_SIZE,]
    _cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
    tl.store(
        output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
        _cache,
        mask=idx[:, None] < NUM_SEQS,
    )


@torch.no_grad()
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
    """
    Gather the cos/sin cache into a no-pad (unpadded) layout, with two different modes.
    Args:
        lengths: shape (num_seqs,), stores the length of each sequence.
        cache: shape (max_rotary_position (e.g. 2048), head_dim), cos/sin cache constructed in the model.
        is_prompts: bool, marks whether we are in prefill mode.
    For prefill mode:
        the cos/sin cache for each sequence covers all of its positions, i.e. its full length.
    For decoding mode:
        the cos/sin cache is only needed for the last token of each sequence.
    """

    _, hidden_dim = cache.shape
    num_seqs = lengths.numel()

    BLOCK_SIZE = 16
    if hidden_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    cache_stride = cache.stride(0)
    hidden_stride = cache.stride(1)

    if is_prompts:
        total_length = lengths.sum().item()
        cumsum_lens = torch.cumsum(lengths, dim=0)
        output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
        grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
        prefill_cache_kernel[grid](
            cache,
            cumsum_lens,
            output,
            cache_stride,
            hidden_stride,
            total_length,
            HIDDEN_DIM=hidden_dim,
            N_ELEMENTS=triton.next_power_of_2(num_seqs),
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        # BUG: a memory access error occurs when a deepcopy of lengths is used in place of lengths
        nlengths = torch.as_tensor(lengths) - 1
        output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
        grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
        decoding_cache_kernel[grid](
            cache,
            nlengths,
            output,
            cache_stride,
            hidden_stride,
            HIDDEN_DIM=hidden_dim,
            NUM_SEQS=num_seqs,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return output
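
For reference, a sketch of how the wrapper above might be exercised, together with a plain-PyTorch equivalent of what each mode gathers (not part of the commit; device, dtypes, and sizes are illustrative):

import torch

from colossalai.kernel.triton import get_xine_cache

lengths = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
cos_cache = torch.randn(2048, 64, device="cuda")  # (max_rotary_position, head_dim)

# Prefill: one row per token, covering positions 0 .. length-1 of every sequence.
prefill_cos = get_xine_cache(lengths, cos_cache, is_prompts=True)  # [16, 64]
ref_prefill = torch.cat([cos_cache[:l] for l in lengths.tolist()], dim=0)

# Decoding: one row per sequence, the entry for its last position.
decode_cos = get_xine_cache(lengths, cos_cache, is_prompts=False)  # [3, 64]
ref_decode = cos_cache[(lengths - 1).long()]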
