
Commit e8cc372

Alex4210987 and xinxyxiao authored
[Enhancement] Add flash attn example for AMD MI300 series (#671)
* [Enhancement] Refactor buffer index handling for improved precision and clarity (#668)
  - Enhanced buffer index handling to address precision issues by removing redundant operations.
  - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
  - Updated related documentation to reflect changes in buffer management practices.
* Remove obsolete test script for AMD example, streamlining the examples directory.
* Remove unused dtype_size variable in AMD example script to streamline code.

Co-authored-by: xinxyxiao <xinyxiao@amd.com>
1 parent 98f93db commit e8cc372

File tree

1 file changed: +270 −0 lines changed
Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
#
# Modified to implement FlashAttention-2 forward pass principles.
# Corrected loop implementation using T.while_loop.

import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import itertools
import argparse
from functools import partial


# The PyTorch reference implementation is unchanged.
def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    assert Q.size(
        2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def get_v2_configs():
    """Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
    block_M = [64, 128, 256]
    block_N = [32, 64, 128]
    threads = [128, 256, 512]
    num_split_q = [32, 64, 128]
    num_stages = [1, 2, 3]
    enable_rasterization = [True]
    k_pack = [2]

    valid_configs = []

    for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
                                                      num_stages, enable_rasterization, k_pack):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k
        })
    if not valid_configs:
        valid_configs.append({
            'block_M': 64,
            'block_N': 64,
            'num_split_q': 64,
            'threads': 256,
            'num_stages': 1,
            'enable_rasterization': True,
            'k_pack': 2
        })
    return valid_configs


@tilelang.autotune(configs=get_v2_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn_v2(
    batch,
    heads,
    seq_len,
    dim,
    is_causal,
    groups,
    block_M: int,
    block_N: int,
    num_split_q: int,
    threads: int,
    num_stages: int,
    enable_rasterization: bool,
    k_pack: int,
):
    scale = (1.0 / dim)**0.5 * 1.44269504
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
    dtype = "float16"
    accum_dtype = "float"

    v_vec_size = 4

    vec_size = 4 * k_pack

    @T.macro
    def compute_block(
        bz,
        by,
        bx,
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(kv_shape, dtype),
        V: T.Tensor(kv_shape, dtype),
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        m_i: T.FragmentBuffer([block_M], accum_dtype),
        l_i: T.FragmentBuffer([block_M], accum_dtype),
    ):
        Q_shared = T.alloc_shared([block_M, dim], dtype)
        K_shared = T.alloc_shared([block_N, dim], dtype)
        V_shared = T.alloc_shared([block_N, dim], dtype)
        P_shared = T.alloc_shared([block_M, block_N], dtype)

        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
        m_prev = T.alloc_fragment([block_M], accum_dtype)
        scale_factor = T.alloc_fragment([block_M], accum_dtype)

        q_block_offset = bx * block_M
        T.copy(
            Q[bz, q_block_offset:q_block_offset + block_M, by, :],
            Q_shared,
            coalesced_width=vec_size)

        loop_end_k = T.ceildiv(q_block_offset +
                               block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
        for k in T.Pipelined(loop_end_k, num_stages=num_stages):
            kv_idx = k * block_N
            T.copy(
                K[bz, kv_idx:kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size)
            T.copy(
                V[bz, kv_idx:kv_idx + block_N, by // groups, :],
                V_shared,
                coalesced_width=v_vec_size)

            T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)

            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, acc_s[i, j],
                                                 -T.infinity(acc_s.dtype))

            T.copy(m_i, m_prev)
            T.reduce_max(acc_s, m_i, dim=1, clear=False)

            for i in T.Parallel(block_M):
                sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
                l_i[i] *= sf
                scale_factor[i] = sf

            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scale_factor[i]

            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)

            row_sum = T.alloc_fragment([block_M], accum_dtype)
            T.reduce_sum(acc_s, row_sum, dim=1)
            for i in T.Parallel(block_M):
                l_i[i] += row_sum[i]

            T.copy(acc_s, P_shared)
            T.sync_threads()

            T.gemm(P_shared, V_shared, acc_o)

    # Fix: the macro is defined outside the kernel body for a cleaner code structure.
    @T.macro
    def scale_and_write_back(src_buffer, scale_vector, dest_tensor, bz, by, q_block_offset):
        # This macro performs the fused scaling and write-back, which is critical for performance.
        for i, j in T.Parallel(block_M, dim):
            dest_tensor[bz, q_block_offset + i, by, j] = src_buffer[i, j] * scale_vector[i]

    @T.macro
    def flash_attn_forward_kernel(Q: T.Tensor(q_shape, dtype), K: T.Tensor(kv_shape, dtype),
                                  V: T.Tensor(kv_shape, dtype), Output: T.Tensor(q_shape, dtype)):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(10, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

            bx = T.alloc_var("int32")
            bx[0] = b_split

            with T.While(bx[0] < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx[0]

                compute_block(bz, by, current_bx, Q, K, V, acc_o, m_i, l_i)

                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                # Fix: the call to the macro is now clearer to the compiler.
                q_block_offset = current_bx * block_M
                scale_and_write_back(acc_o, l_inv, Output, bz, by, q_block_offset)

                bx[0] = current_bx + num_split_q

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        flash_attn_forward_kernel(Q, K, V, Output)

    return main


# The main driver function is unchanged.
def main_v2(batch: int = 1,
            heads: int = 8,
            seq_len: int = 4096,
            dim: int = 128,
            is_causal: bool = False,
            groups: int = 1):

    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    print("Starting autotuning for FlashAttention-V2...")
    kernel = fast_flashattn_v2(batch, heads, seq_len, dim, is_causal, groups=groups)
    print(f"Autotuning finished. Best Configuration: {kernel.config}")

    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    print("Verifying correctness...")
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program_processed, warmup=100)
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(
        f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=8, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--groups', type=int, default=1, help='groups')
    args = parser.parse_args()
    main_v2(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
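
Note on the kernel above: compute_block keeps a running row maximum (m_i) and a running row sum (l_i), and rescales the partial output acc_o whenever a new K/V tile raises the maximum. A minimal PyTorch sketch of that online-softmax identity, for one query row split across two tiles, is shown below (illustration only, not part of this commit; the tile sizes and variable names are made up):

import torch

# Scores and V tiles for a single query row, split across two K/V blocks.
s1, s2 = torch.randn(4), torch.randn(4)
v1, v2 = torch.randn(4, 8), torch.randn(4, 8)

# Reference: softmax over the full row, then weight V.
ref = torch.softmax(torch.cat([s1, s2]), dim=0) @ torch.cat([v1, v2])

# Online pass: process tile 1, then rescale when tile 2 raises the running max.
m = s1.max()
l = torch.exp(s1 - m).sum()
o = torch.exp(s1 - m) @ v1

m_new = torch.maximum(m, s2.max())
sf = torch.exp(m - m_new)                 # plays the role of scale_factor in compute_block
l = l * sf + torch.exp(s2 - m_new).sum()
o = o * sf + torch.exp(s2 - m_new) @ v2

assert torch.allclose(o / l, ref, atol=1e-5)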

0 commit comments
