237 changes: 237 additions & 0 deletions examples/amd/example_amd_flash_attn_fwd.py
@@ -0,0 +1,237 @@
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import itertools
import argparse
from functools import partial


def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
Contributor review comment (medium):
This comment is in Chinese. For consistency and to make the code more accessible to a wider audience, it's best to use English for all comments.

Suggested change:
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
    # PyTorch reference implementation

    assert Q.size(
        2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
    dim = Q.size(-1)
    K = K.repeat_interleave(groups, dim=2)
    V = V.repeat_interleave(groups, dim=2)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def get_configs():
"""Generates configurations for the autotuner, tailored for FA-2 style parallelism."""
block_M = [64, 128, 256]
block_N = [32, 64, 128]
threads = [128, 256, 512]
Contributor review comment (medium):
Consider adding a comment to explain the purpose of annotating the layout for Q_shared and how it optimizes memory access.

Suggested change:
        T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) # Optimize memory access by using a swizzled layout
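A minimal placement sketch for context (not part of this PR's diff), assuming tilelang.layout.make_swizzled_layout is available in this tilelang version; note the file imports tilelang itself rather than a tl alias:

                Q_shared = T.alloc_shared([block_M, dim], dtype)
                # Swizzle Q's shared-memory layout so threads in a wavefront hit
                # different banks, reducing bank conflicts during the Q*K^T GEMM.
                T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})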

    num_split_q = [32, 64, 128]
    num_stages = [0, 1, 2]
    enable_rasterization = [True, False]
    k_pack = [1, 2]

    valid_configs = []

    for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads,
                                                      num_stages, enable_rasterization, k_pack):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k
        })
    valid_configs.append({
        'block_M': 64,
        'block_N': 64,
        'num_split_q': 64,
        'threads': 256,
        'num_stages': 1,
        'enable_rasterization': True,
        'k_pack': 2
    })
    return valid_configs


@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
        batch,
        heads,
        seq_len,
        dim,
        is_causal,
        groups,
        block_M: int,
        block_N: int,
        num_split_q: int,
        threads: int,
        num_stages: int,
        enable_rasterization: bool,
        k_pack: int,
):
    scale = (1.0 / dim)**0.5 * 1.44269504
    head_kv = heads // groups
    q_shape = [batch, seq_len, heads, dim]
    kv_shape = [batch, seq_len, head_kv, dim]
Contributor review comment (medium):
The magic number 1.44269504 is an approximation of 1/ln(2) (or log2(e)), used to convert natural exponentiation to base-2 exponentiation (exp(x) = exp2(x * log2(e))). Adding a comment to explain this will improve readability.

Suggested change:
    scale = (1.0 / dim)**0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
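The identity behind this constant can be checked with a short standalone snippet (plain Python, not part of this diff):

    import math

    x = 0.7
    # exp(x) == 2**(x * log2(e)); the kernel folds log2(e) ~= 1.44269504 into `scale`
    # so the online softmax can use exp2, which is cheaper on GPUs than exp.
    assert abs(math.exp(x) - 2.0**(x * math.log2(math.e))) < 1e-12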

    dtype = "float16"
    accum_dtype = "float"

    v_vec_size = 4
    vec_size = 4 * k_pack

    @T.prim_func
    def main(
            Q: T.Tensor(q_shape, dtype),
            K: T.Tensor(kv_shape, dtype),
            V: T.Tensor(kv_shape, dtype),
            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined):
            T.use_swizzle(10, enable=enable_rasterization)

            bz = byz_combined // heads
            by = byz_combined % heads

            num_q_blocks = T.ceildiv(seq_len, block_M)

            bx = T.alloc_var("int32")
            bx[0] = b_split

            with T.While(bx[0] < num_q_blocks):
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                m_i = T.alloc_fragment([block_M], accum_dtype)
                l_i = T.alloc_fragment([block_M], accum_dtype)
                T.fill(acc_o, 0)
                T.fill(m_i, -T.infinity(accum_dtype))
                T.fill(l_i, 0)

                current_bx = bx[0]
                q_block_offset = current_bx * block_M

                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                P_shared = T.alloc_shared([block_M, block_N], dtype)

                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                m_prev = T.alloc_fragment([block_M], accum_dtype)
                scale_factor = T.alloc_fragment([block_M], accum_dtype)

                T.copy(
                    Q[bz, q_block_offset:q_block_offset + block_M, by, :],
                    Q_shared,
                    coalesced_width=vec_size)

                loop_end_k = T.ceildiv(q_block_offset + block_M,
                                       block_N) if is_causal else T.ceildiv(seq_len, block_N)

                for k in T.Pipelined(loop_end_k, num_stages=num_stages):
                    kv_idx = k * block_N

                    T.copy(
                        K[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        K_shared,
                        coalesced_width=vec_size)
                    T.copy(
                        V[bz, kv_idx:kv_idx + block_N, by // groups, :],
                        V_shared,
                        coalesced_width=v_vec_size)

                    T.clear(acc_s)
                    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack)

                    if is_causal:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j,
                                                         acc_s[i, j], -T.infinity(acc_s.dtype))

                    T.copy(m_i, m_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)

                    for i in T.Parallel(block_M):
                        sf = T.exp2(m_prev[i] * scale - m_i[i] * scale)
                        l_i[i] *= sf
                        scale_factor[i] = sf

                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scale_factor[i]

                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale)

                    row_sum = T.alloc_fragment([block_M], accum_dtype)
                    T.reduce_sum(acc_s, row_sum, dim=1)
                    for i in T.Parallel(block_M):
                        l_i[i] += row_sum[i]

                    T.copy(acc_s, P_shared)
                    T.sync_threads()

                    T.gemm(P_shared, V_shared, acc_o)

                l_inv = T.alloc_fragment([block_M], accum_dtype)
                for i in T.Parallel(block_M):
                    safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0)
                    l_inv[i] = 1.0 / safe_l

                for i, j in T.Parallel(block_M, dim):
                    Output[bz, q_block_offset + i, by, j] = acc_o[i, j] * l_inv[i]

                bx[0] = current_bx + num_split_q

    return main


def main(batch: int = 1,
         heads: int = 8,
         seq_len: int = 4096,
         dim: int = 128,
         is_causal: bool = False,
         groups: int = 1):

    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    print("Starting autotuning for FlashAttention-V2...")
    kernel = fast_flashattn(batch, heads, seq_len, dim, is_causal, groups=groups)
    print(f"Autotuning finished. Best Configuration: {kernel.config}")

    ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups)

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    print("Verifying correctness...")
    profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program_processed, warmup=100)
    print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops")

    latency = profiler.do_bench(warmup=100)
    print(
        f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=8, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--groups', type=int, default=1, help='groups')
    args = parser.parse_args()
    main(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups)
4 changes: 2 additions & 2 deletions src/tl_templates/hip/reduce.h
@@ -22,7 +22,7 @@ struct MinOp {
  }
};

template <class Reducer, int threads, int scale> struct AllReduce {
template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
Contributor review comment (medium):
The new template parameter thread_offset is introduced but it is not used within the AllReduce struct. This is potentially dead code and makes the template signature more complex than necessary. If this parameter is intended for future use, a comment explaining its purpose would be beneficial. Otherwise, it should be removed to improve code clarity.

  static_assert(threads == 1024 || threads == 512 || threads == 256 ||
                threads == 128 || threads == 64 || threads == 32 ||
                threads == 16 || threads == 8 || threads == 4 || threads == 2);
@@ -43,7 +43,7 @@ template <class Reducer, int threads, int scale> struct AllReduce {
    if constexpr (offset == scale) {
      return x;
    } else {
      return AllReduce<Reducer, offset, scale>::run(x, red_buf);
      return AllReduce<Reducer, offset, scale, thread_offset>::run(x, red_buf);
    }
  }
};