173 changes: 20 additions & 153 deletions examples/flash_attention/example_mha_bwd.py
@@ -149,110 +149,7 @@ def flash_bwd_post(
@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_atomic_add(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, accum_dtype), # type: ignore
dV: T.Tensor(shape, accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dk_shared = T.alloc_shared([block_M, dim], accum_dtype)
dv_shared = T.alloc_shared([block_M, dim], accum_dtype)

T.annotate_layout({
dQ: make_dq_layout(dQ),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_causal:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)

T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)

for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)

T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.atomic_add(dV[bz, by * block_M:(by + 1) * block_M, bx, :], dv_shared)
T.copy(dk, dk_shared)
T.atomic_add(dK[bz, by * block_M:(by + 1) * block_M, bx, :], dk_shared)

return flash_bwd


@tilelang.jit(pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn_bwd_split(batch,
heads,
seq_len,
dim,
is_causal,
block_M,
block_N,
threads=128,
num_stages=2):
def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
@@ -271,9 +168,13 @@ def flash_bwd(
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
# should not store K to local if dim is large
# K_local = T.alloc_fragment([block_M, dim], dtype)
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
# V_local = T.alloc_fragment([block_M, dim], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
@@ -301,7 +202,7 @@ def flash_bwd(
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -328,8 +229,7 @@ def flash_bwd(
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
@@ -341,14 +241,13 @@ def flash_bwd(
class _attention(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, causal, use_atomic=True):
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
o, lse = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
ctx.use_atomic = use_atomic
return o

@staticmethod
@@ -367,29 +266,14 @@ def maybe_contiguous(x):
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
delta = kernel_prep(o, do)

if ctx.use_atomic:
kernel = flashattn_bwd_atomic_add(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.zeros(shape, dtype=torch.float32, device=q.device)
dv = torch.zeros(shape, dtype=torch.float32, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
dk = dk.to(torch.float16)
dv = dv.to(torch.float16)
else:
kernel = flashattn_bwd_split(
BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N, threads=128, num_stages=2)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)

return dq, dk, dv, None, None
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
shape = [BATCH, N_CTX, H, D_HEAD]
dq = torch.zeros(shape, dtype=torch.float32, device=q.device)
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq = kernel_post(dq)
return dq, dk, dv, None


attention = _attention.apply
@@ -415,9 +299,7 @@ def main(
N_CTX: int = 1024,
D_HEAD: int = 64,
causal: bool = False,
use_atomic: bool = True,
):
print(f"Test with use_atomic: {use_atomic}")
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 5 * flops_per_matmul
if causal:
@@ -428,7 +310,7 @@ def main(
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal, use_atomic)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dQ, Q.grad = Q.grad.clone(), None
dK, K.grad = K.grad.clone(), None
@@ -444,7 +326,6 @@ def main(
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
print('All checks passed.✅')

def run():
O_ref.backward(dO, retain_graph=True)
@@ -468,20 +349,6 @@ def run1():
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', action='store_true', help='Causal flag')
parser.add_argument(
'--use_atomic', action='store_true', default=False, help='Use atomic add for dK/dV')
parser.add_argument(
'--use_split', action='store_true', default=False, help='Use split for dK/dV')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')

⚠️ Potential issue | 🔴 Critical

Fix argparse boolean handling.

Using type=bool with argparse doesn't work as expected. Any non-empty string (including "False") will be converted to True. Command-line usage like --causal False will incorrectly set causal=True.

Apply this diff to fix the boolean argument parsing:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', action='store_true', help='Causal flag')

Alternatively, if you need explicit True/False control:

-    parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
+    parser.add_argument('--causal', type=lambda x: x.lower() == 'true', default=False, help='Causal flag')
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd.py around line 353, the argparse
call uses type=bool which treats any non-empty string (e.g. "False") as True;
replace it with a proper boolean flag by using parser.add_argument('--causal',
action='store_true', help='Causal flag') so passing --causal sets True and
omitting it leaves False; if you need explicit True/False parsing from strings
instead, add a small str2bool helper that maps common truthy/falsey strings to
booleans and use type=str2bool (and keep default=False).
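
A minimal sketch of the str2bool alternative mentioned above (the helper name and the set of accepted strings are illustrative, not part of this PR):

import argparse

def str2bool(value: str) -> bool:
    # Map common truthy/falsey strings to booleans for argparse.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

parser = argparse.ArgumentParser()
# Explicit True/False control: both --causal true and --causal false parse correctly.
parser.add_argument('--causal', type=str2bool, default=False, help='Causal flag')
args = parser.parse_args()

With this helper, --causal false actually yields False, while the simpler action='store_true' form suggested first remains the smaller change for this example script.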

args = parser.parse_args()

# Handle backward compatibility and logic
if args.use_split:
use_atomic = False
elif args.use_atomic:
use_atomic = True
else:
# Default: use atomic
use_atomic = True

main(args.batch, args.h, args.n_ctx, args.d_head, args.causal, use_atomic)
main(args.batch, args.h, args.n_ctx, args.d_head, args.causal)