Performant backward Triton implementation with separated dkdv and dq kernels #122

Merged · 34 commits · Feb 4, 2025

Changes from 1 commit (of 34 commits)
f8c1ee5  added the split file (jtang10, Dec 17, 2024)
c1b5fae  overhauled split file, need to add new kernels (jtang10, Dec 17, 2024)
a3f622f  copied triton fa over for reference (jtang10, Dec 18, 2024)
90d639e  added comments (jtang10, Dec 18, 2024)
0560963  preprocess and dkdv done (jtang10, Dec 27, 2024)
9328e5a  fixed dkdv, added dq (jtang10, Dec 27, 2024)
682eb1f  fixed assumption on q, kv length different, run but incorrect (jtang10, Jan 6, 2025)
d0afca7  added standalone test for split bwd kernel (jtang10, Jan 7, 2025)
1c542d2  minor change on the ptr arith (jtang10, Jan 7, 2025)
8b1629f  separated the dkdv and dq kernels (jtang10, Jan 8, 2025)
2e4c812  GQA works now, onto seqlen q != k (jtang10, Jan 9, 2025)
be65b39  dk,dq working, dv still failing (jtang10, Jan 11, 2025)
5e5ad91  fixed the masking and num_step calc, now q==k works (jtang10, Jan 21, 2025)
2e67d95  added debug print with interpreter, might not work entirely w/o next … (jtang10, Jan 28, 2025)
149dd4b  fixed all issues with q != k (jtang10, Jan 28, 2025)
c16a4f7  fixed varlen issue (jtang10, Jan 28, 2025)
58cc0f2  fixup on debug print (jtang10, Jan 28, 2025)
88054a7  fixed dropout, esp w/ varlen (jtang10, Jan 29, 2025)
059f665  added USE_EXP2 toggle (jtang10, Jan 29, 2025)
ac18466  added noncausal kernel (jtang10, Jan 29, 2025)
9fa9688  updated internal test for noncausal and use_exp2 (jtang10, Jan 29, 2025)
10e5468  formatting (jtang10, Jan 29, 2025)
91b30ac  fixed dropout from seed bug (jtang10, Jan 30, 2025)
00d2c77  added envvar USE_SPLIT to toggle btw bwd kernels (jtang10, Jan 30, 2025)
58941ed  fixed the qkv pack issue and removed hack (jtang10, Jan 31, 2025)
7fadfe3  added the split kernel into interface_fa.py (jtang10, Jan 31, 2025)
c97f298  change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default (jtang10, Jan 31, 2025)
c6f5607  removed redundant file (jtang10, Jan 31, 2025)
383d8c7  fixed missing import in test (jtang10, Feb 3, 2025)
8e99137  fixed import in interface_fa.py (jtang10, Feb 3, 2025)
8436dc7  revert changes in flash_attn_interface.py (jtang10, Feb 3, 2025)
ada4bb8  updated strides to adapt to various tensor init shape (jtang10, Feb 3, 2025)
bacc596  fixed issue that dqkv not zero'd (jtang10, Feb 4, 2025)
0055b35  disabled the AMD local test (jtang10, Feb 4, 2025)
minor change on the ptr arith
jtang10 committed Feb 3, 2025
commit 1c542d2a141a6c7a7e5164f89fd8e7f3b60ce91e
62 changes: 32 additions & 30 deletions flash_attn/flash_attn_triton_amd/bwd_prefill_split.py
@@ -125,7 +125,7 @@ def _attn_bwd_dkdv(

# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m, mask=offs_m < seqlen_q, other=0.0)
m = tl.load(M + offs_m, mask=offs_m < seqlen_q)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
@@ -141,7 +141,7 @@ def _attn_bwd_dkdv(
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m, mask=offs_m < seqlen_q, other=0.0)
Di = tl.load(D + offs_m, mask=offs_m < seqlen_q)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
if DROPOUT:
@@ -153,8 +153,9 @@
curr_m += step_m
qT_ptrs += step_m * stride_qm
do_ptrs += step_m * stride_qm
curr_dropout_offset += step_m * stride_qm
curr_philox_offset += step_m * stride_qm
if DROPOUT:
curr_dropout_offset += step_m * stride_qm
curr_philox_offset += step_m * stride_qm
return dk, dv


@@ -241,7 +242,7 @@ def _attn_bwd_dq(dq, # output
# num_pid = max(
# tl.cdiv(max_seqlen_k // BLOCK_N1),
# tl.cdiv(max_seqlen_q // BLOCK_M2))
# grid = (num_pid, 1, batch * nheads_q)
# grid = (num_pid, batch * nheads_q)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO, DQ, DK, DV,
@@ -270,7 +271,7 @@ def _bwd_kernel(
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
# program ids
pid = tl.program_id(0)
bhqid = tl.program_id(2)
bhqid = tl.program_id(1)
bid = bhqid // HQ
hqid = bhqid % HQ

@@ -302,21 +303,24 @@
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE = HQ // HK
if GROUP_SIZE != 1:
off_hk = hqid // GROUP_SIZE
hkid = hqid // GROUP_SIZE
else:
off_hk = hqid
hkid = hqid
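# e.g. with HQ = 8 query heads and HK = 2 kv heads, GROUP_SIZE = 4, so
# query heads 0-3 read kv head 0 and query heads 4-7 read kv head 1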

# input tensor offsets
Q += bid * stride_qb + hqid * stride_qh + q_start * stride_qm
K += bid * stride_kb + off_hk * stride_kh + k_start * stride_kn
V += bid * stride_vb + off_hk * stride_vh + k_start * stride_vn
DO += bid * stride_qb + hqid * stride_qh + q_start * stride_qm
M += bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
Delta += bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
adj_delta = bhqid * stride_deltab + q_start * stride_deltam
Q += adj_q
K += adj_kv
V += adj_kv
DO += adj_q
M += adj_delta
Delta += adj_delta
# output tensor offsets
DQ += bid * stride_qb + hqid * stride_qh + q_start * stride_qm
DK += bid * stride_kb + off_hk * stride_kh + k_start * stride_kn
DV += bid * stride_vb + off_hk * stride_vh + k_start * stride_vn
DQ += adj_q
DK += adj_kv
DV += adj_kv

# dropout is a boolean mask that will clear out the multiplicand tensor
# wherever the dropout's entry is 0. It is generated by the tl.rand(seed,
@@ -359,10 +363,9 @@ def _bwd_kernel(
mask_k = offs_k < ACTUAL_HEAD_DIM
mask_kv &= mask_k[None, :]
# load K and V: they stay in SRAM throughout the inner loop.
k_ptrs = K + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
v_ptrs = V + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
k = tl.load(k_ptrs, mask=mask_kv, other=0.0)
v = tl.load(v_ptrs, mask=mask_kv, other=0.0)
offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
k = tl.load(K + offs_kv, mask=mask_kv, other=0.0)
v = tl.load(V + offs_kv, mask=mask_kv, other=0.0)

MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
num_steps = BLOCK_N1 // MASK_BLOCK_M1
@@ -390,7 +393,7 @@ def _bwd_kernel(
else:
start_m = 0

num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1)
num_steps = (seqlen_q - start_m) // BLOCK_M1
# only the blocks on the causal mask diagonal need masking
dk, dv = _attn_bwd_dkdv(
dk, dv, # output tensors
@@ -407,11 +410,9 @@
)

# Write back dV and dK.
dv_ptrs = DV + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
tl.store(dv_ptrs, dv, mask=mask_kv)
tl.store(DV + offs_kv, dv, mask=mask_kv)
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
tl.store(dk_ptrs, dk, mask=mask_kv)
tl.store(DK + offs_kv, dk, mask=mask_kv)

# THIS BLOCK DOES DQ:
if pid < num_qblocks:
@@ -420,8 +421,9 @@
# TODO: now pid is only a function of max_seqlen_k, so it's incorrect for the
start_m = pid * BLOCK_M2
# seqlen_q > seqlen_k, no need to process these tiles for dq
if start_m + BLOCK_M2 < seqlen_delta:
return
# TODO: fix this
# if start_m + BLOCK_M2 < seqlen_delta:
# return
end_n = start_m + BLOCK_M2
# when seqlen_q < seqlen_k, the end_n is padded
end_n += seqlen_delta
@@ -439,7 +441,7 @@
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
do = tl.load(do_ptrs, mask=mask_q, other=0.0)

m = tl.load(M + offs_m, mask=offs_m < seqlen_q, other=0.0)
m = tl.load(M + offs_m, mask=offs_m < seqlen_q)
m = m[:, None]

# Compute dQ for masked (diagonal) blocks.
@@ -584,7 +586,7 @@ def attention_prefill_backward_triton_split_impl(
num_pid = max(
(max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1,
(max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2)
grid = (num_pid, 1, batch * nheads_q)
grid = (num_pid, batch * nheads_q)
_bwd_kernel[grid](
q, k, v, sm_scale, o, do, dq, dk, dv,
softmax_lse, delta,
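
For reference, a minimal host-side sketch of the launch geometry used above (the helper name and the BLOCK_M2 / BLOCK_N1 values are placeholders for illustration, not the kernel's tuned defaults):

    def split_bwd_grid(batch, nheads_q, max_seqlen_q, max_seqlen_k,
                       BLOCK_M2=64, BLOCK_N1=64):
        # a single pid axis is sized to cover both the dk/dv tiles
        # (BLOCK_N1 steps over seqlen_k) and the dq tiles (BLOCK_M2 steps
        # over seqlen_q); inside the kernel, pid < num_qblocks gates the dq work
        num_kblocks = (max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1
        num_qblocks = (max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2
        num_pid = max(num_kblocks, num_qblocks)
        # the second grid axis enumerates (batch, query head) pairs
        return (num_pid, batch * nheads_q)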