Commit d66b83c

[Example] Update GQA varlen fwd and MHA varlen fwd (#1071)
1 parent e57ef58 commit d66b83c

3 files changed: +405 -158 lines changed
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
# ruff: noqa
import argparse
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from einops import rearrange, repeat
from tilelang.profiler import do_bench
from varlen_utils import generate_random_padding_mask, generate_qkv

tilelang.disable_cache()


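# Eager PyTorch reference used only for correctness checking: each KV head is
# repeated across its query-head group so GQA reduces to plain multi-head attention.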
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),
    upcast=True,
):
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    dim = q.shape[-1]
    scale = (1.0 / dim)**0.5
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    scores = torch.einsum("bthd,bshd->bhts", q, k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if causal:
        # Top-left aligned causal mask: query t only attends to keys s <= t.
        causal_mask = torch.triu(
            torch.ones(q.shape[1], k.shape[1], dtype=torch.bool, device=q.device), diagonal=1)
        scores.masked_fill_(rearrange(causal_mask, "t s -> 1 1 t s"), float("-inf"))
    scores = scores * scale
    attention = torch.softmax(scores, dim=-1).to(v.dtype)

    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@tilelang.jit(
    out_idx=[6], pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    })
def flashattn(batch_size,
              groups,
              UQ,
              UKV,
              heads,
              dim,
              is_causal,
              block_M=64,
              block_N=64,
              num_stages=1,
              threads=128):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    q_shape = [UQ, heads, dim]
    kv_shape = [UKV, head_kv, dim]
    o_shape = [UQ, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

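    # Grid: (ceil(max_seqlen_q / block_M), heads, batch). Each program instance
    # produces one block_M x dim output tile for a single head of one sequence.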
    @T.prim_func
    def main(
            Q_unpad: T.Tensor(q_shape, dtype),
            K_unpad: T.Tensor(kv_shape, dtype),
            V_unpad: T.Tensor(kv_shape, dtype),
            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
                threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            batch_idx = bz
            head_idx = by
            kv_head_idx = head_idx // groups

            q_start_idx = cu_seqlens_q[batch_idx]
            k_start_idx = cu_seqlens_k[batch_idx]
            v_start_idx = cu_seqlens_k[batch_idx]
            q_end_idx = cu_seqlens_q[batch_idx + 1]
            k_end_idx = cu_seqlens_k[batch_idx + 1]
            v_end_idx = cu_seqlens_k[batch_idx + 1]

            q_current_seqlen = q_end_idx - q_start_idx
            k_current_seqlen = k_end_idx - k_start_idx
            v_current_seqlen = v_end_idx - v_start_idx

            T.copy(
                Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                Q_shared)
            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i >= q_current_seqlen:
                    Q_shared[i, d] = 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

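            # Stream the KV sequence of this batch element in block_N tiles while
            # maintaining online-softmax state (running max, rescale factor, row sum).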
            loop_range = T.ceildiv(k_current_seqlen, block_N)

            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
                            kv_head_idx, :], K_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= k_current_seqlen:
                        K_shared[i, d] = 0

                if is_causal:
                    # Mask with -inf wherever the key position lies ahead of the query
                    # position or either index falls outside this sequence's length.
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
                                                     (bx * block_M + i >= q_current_seqlen) or
                                                     (k * block_N + j >= k_current_seqlen),
                                                     -T.infinity(acc_s.dtype), 0)
                else:
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
                                                      k * block_N + j >= k_current_seqlen),
                                                     -T.infinity(acc_s.dtype), 0)

                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)

                # Rescale the running statistics; `scale` already folds in log2(e),
                # so exp2 here is equivalent to exp of the scaled logits.
                for i in T.Parallel(block_M):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_M):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)

                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] *= scores_scale[i]

                T.copy(
                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
                            kv_head_idx, :], V_shared)
                for i, d in T.Parallel(block_N, dim):
                    if k * block_N + i >= v_current_seqlen:
                        V_shared[i, d] = 0

                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

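            # Epilogue: normalize by the accumulated softmax denominator and write
            # back only the rows that lie within this sequence's true query length.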
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)

            for i, d in T.Parallel(block_M, dim):
                if bx * block_M + i < q_current_seqlen:
                    Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d]

    return main

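# Driver: build padded Q/K/V, unpad them with generate_qkv, run the TileLang kernel
# on the packed tensors, re-pad the output, and compare against the eager reference.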
def main(batch: int = 1,
         heads: int = 64,
         q_seqlen: int = 2048,
         k_seqlen: int = 2048,
         dim: int = 128,
         groups: int = 16,
         is_causal: bool = False):
    assert heads % groups == 0, "heads must be divisible by groups"

    flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    tilelang.testing.set_random_seed(0)

    dtype = torch.float16
    device = torch.device("cuda")

    head_kv = heads // groups
    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)

    query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
    key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")

    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        _,
        _,
    ) = generate_qkv(
        q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)

    UQ = q_unpad.shape[0]
    UKV = k_unpad.shape[0]

    kernel = flashattn(
        batch,
        groups,
        UQ,
        UKV,
        heads,
        dim,
        is_causal,
        block_M=64,
        block_N=64,
        num_stages=1,
        threads=128)

    out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
    out = output_pad_fn(out_unpad)

    out_ref, _ = attention_ref(
        q,
        k,
        v,
        query_padding_mask=query_padding_mask,
        key_padding_mask=key_padding_mask,
        causal=is_causal,
    )
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")
    latency = do_bench(
        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=64, help='query heads')
    parser.add_argument('--groups', type=int, default=16, help='query heads per KV head (GQA groups)')
    parser.add_argument('--q_seqlen', type=int, default=2048, help='query sequence length')
    parser.add_argument('--k_seqlen', type=int, default=2048, help='key/value sequence length')
    parser.add_argument('--dim', type=int, default=128, help='head dim')
    parser.add_argument('--is_causal', action='store_true', help='causal attention')
    args = parser.parse_args()
    main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups,
         args.is_causal)
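
For a quick smoke test outside the CLI, main() can also be called directly with the same values the argparse defaults use; a minimal sketch, assuming a CUDA device and the varlen_utils helper are importable, and using a hypothetical module name since this page does not show the new file's path:

    # Hypothetical driver snippet; mirrors the argparse defaults above.
    from example_gqa_varlen_fwd import main  # hypothetical module name
    main(batch=8, heads=64, q_seqlen=2048, k_seqlen=2048, dim=128, groups=16, is_causal=False)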
