Commit 2496f5b

triton: cascade kernels (#396)
1 parent 68c3719 commit 2496f5b

5 files changed: +421 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from . import cascade
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
from typing import Optional

import torch

from .kernels.cascade import (
    merge_state_in_place_kernel,
    merge_state_kernel,
    merge_states_kernel,
    variable_length_merge_states_kernel,
)
from .utils import check_device, check_dim, check_input, check_shape


def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a).to(s_a.device)
    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
    bdx = head_dim
    bdy = num_heads

    # One Triton program per sequence position.
    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged


def merge_state_in_place(
    v: torch.Tensor,
    s: torch.Tensor,
    v_other: torch.Tensor,
    s_other: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    check_input(v)
    check_input(s)
    check_input(v_other)
    check_input(s_other)
    check_device([v, s, v_other, s_other])
    check_dim(3, v)
    check_dim(2, s)
    check_dim(3, v_other)
    check_dim(2, s_other)
    check_shape(v, v_other)
    check_shape(s, s_other)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    assert s.dtype == torch.float32
    assert s_other.dtype == torch.float32
    if mask is not None:
        check_dim(1, mask)
        assert v.size(0) == mask.size(0)
        # The optional per-position mask must live on the same device as v.
        assert mask.device == v.device
    seq_len = v.size(0)
    num_heads = v.size(1)
    head_dim = v.size(2)

    bdx = head_dim
    bdy = num_heads
    merge_state_in_place_kernel[(seq_len,)](
        v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy
    )


def merge_states(v: torch.Tensor, s: torch.Tensor):
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(4, v)
    check_dim(3, s)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    assert v.size(2) == s.size(2)
    seq_len = v.size(0)
    num_index_sets = v.size(1)
    num_heads = v.size(2)
    head_dim = v.size(3)
    s = s.to(torch.float32)
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

    bdx = head_dim
    bdy = num_heads
    merge_states_kernel[(seq_len,)](
        v,
        s,
        v_merged,
        s_merged,
        num_index_sets,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged


def variable_length_merge_states(
    v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor
):
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(3, v)
    check_dim(2, s)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    seq_len = indptr.size(0) - 1
    num_heads = v.size(1)
    head_dim = v.size(2)
    s = s.to(torch.float32)
    indptr = indptr.to(torch.int32)
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

    bdx = head_dim
    bdy = num_heads
    variable_length_merge_states_kernel[(seq_len,)](
        v,
        s,
        indptr,
        v_merged,
        s_merged,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged
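
A minimal usage sketch of the merge_state wrapper above (illustrative, not part of this commit; the import path is inferred from the python/flashinfer/triton/ layout shown in this diff, and the shapes, dtypes, and sizes are assumptions). v_a/v_b hold partial attention outputs of shape [seq_len, num_heads, head_dim]; s_a/s_b hold the matching base-2 log-sum-exp values of shape [seq_len, num_heads].

import torch

from flashinfer.triton.cascade import merge_state  # assumed import path

seq_len, num_heads, head_dim = 4, 8, 64
v_a = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
v_b = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
s_a = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
s_b = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)

# Merge the two partial states; the wrapper allocates outputs and launches
# one Triton program per sequence position.
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
assert v_merged.shape == (seq_len, num_heads, head_dim)
assert s_merged.shape == (seq_len, num_heads)
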
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
import triton
import triton.language as tl


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    # Combine two partial softmax states (o, m, d) by rescaling both to a
    # common maximum (base-2) before summing.
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
    o = o / d
    return o, m, d


@triton.jit
def state_get_lse(o, m, d):
    return m + tl.log2(d)


@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)

            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)

            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)

            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )


@triton.jit
def merge_state_in_place_kernel(
    v_ptr,
    s_ptr,
    v_other_ptr,
    s_other_ptr,
    num_heads,
    head_dim,
    mask_ptr,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    if mask_ptr:
        if tl.load(mask_ptr + pos) == 0:
            return

    for head_idx in tl.range(bdy):
        s_val = tl.load(s_ptr + pos * num_heads + head_idx)
        s_other_val = tl.load(s_other_ptr + pos * num_heads + head_idx)
        s_max = tl.maximum(s_val, s_other_val)
        s_val = tl.exp2(s_val - s_max)
        s_other_val = tl.exp2(s_other_val - s_max)
        scale = s_val / (s_val + s_other_val)
        other_scale = s_other_val / (s_val + s_other_val)
        for tx in tl.range(bdx):
            offset = (pos * num_heads + head_idx) * head_dim + tx
            v_vec = tl.load(v_ptr + offset)
            v_other_vec = tl.load(v_other_ptr + offset)
            v_vec = scale * v_vec + other_scale * v_other_vec
            tl.store(v_ptr + offset, v_vec)
        if s_ptr:
            tl.store(
                s_ptr + pos * num_heads + head_idx,
                tl.log2(s_val + s_other_val) + s_max,
            )


@triton.jit
def merge_states_kernel(
    v_ptr,
    s_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_index_sets,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)

    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Running state: o = rescaled value, m = running max, d = denominator;
            # m starts very negative so the first real state dominates.
            o, m, d = 0.0, -5e4, 1.0
            for iter in tl.range(num_index_sets):
                s = tl.load(
                    s_ptr + (pos * num_index_sets + iter) * num_heads + head_idx
                )
                v = tl.load(
                    v_ptr
                    + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim
                    + tx
                )
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
                )


@triton.jit
def variable_length_merge_states_kernel(
    v_ptr,
    s_ptr,
    indptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            o, m, d = 0.0, -5e4, 1.0
            # indptr delimits each position's contiguous run of partial states.
            for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)):
                s = tl.load(s_ptr + iter * num_heads + head_idx)
                v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx)
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
                )
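
For readers cross-checking the kernel math above: state_merge combines two partial softmax states (o, m, d) — rescaled value, running max, and denominator — by rescaling both to a common maximum before summing, and state_get_lse recovers the merged base-2 log-sum-exp as m + log2(d). Below is a plain-PyTorch reference sketch of the same recurrence, applied to the two-state case handled by merge_state_kernel (illustrative only; the names are assumptions and not part of the commit).

import torch


def state_merge_ref(o, m, d, other_o, other_m, other_d):
    # Rescale both partial states to the shared maximum, the usual trick for
    # numerically stable (base-2) log-sum-exp accumulation.
    m_max = torch.maximum(m, other_m)
    d_new = d * torch.exp2(m - m_max) + other_d * torch.exp2(other_m - m_max)
    o_new = o * torch.exp2(m - m_max) + other_o * torch.exp2(other_m - m_max)
    return o_new, m_max, d_new


def merge_state_ref(v_a, s_a, v_b, s_b):
    # s_* are [seq_len, num_heads] base-2 LSE values, broadcast over head_dim.
    o, m, d = state_merge_ref(
        v_a, s_a[..., None], torch.ones_like(s_a)[..., None],
        v_b, s_b[..., None], torch.ones_like(s_b)[..., None],
    )
    # Normalized value and merged LSE, mirroring state_normalize / state_get_lse.
    return o / d, (m + torch.log2(d)).squeeze(-1)
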

python/flashinfer/triton/utils.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
from typing import List

import torch


def check_input(x: torch.Tensor):
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == b.dim(), f"tensors should have same dim"
    for i in range(a.dim()):
        assert a.size(i) == b.size(
            i
        ), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
    device = tensors[0].device
    for t in tensors:
        assert (
            t.device == device
        ), f"All tensors should be on the same device, but got {device} and {t.device}"
