Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance update on the backward split kernel #127

Open
wants to merge 8 commits into
base: main_perf
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flash_attn/flash_attn_triton_amd/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, devic
).expand(-1, -1, -1, HQ // HK, -1)
input_metadata = MetaData(sm_scale=1.3)
input_metadata.layout = "bsghd"

# Adjust flops calculation if needed
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD

input_data = (q, k, v, input_metadata)
else:
raise ValueError("Unsupported benchmark function")

input_metadata.use_exp2 = True
return input_data, flops_per_matmul

def run_benchmark(args, fn_name, fn, mode):
Expand Down
23 changes: 12 additions & 11 deletions flash_attn/flash_attn_triton_amd/bwd_prefill_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def _bwd_kernel_dkdv(
# determine the starting M blocks to skip some initial blocks masked by
# causal.
delta_qk = seqlen_q - seqlen_k
if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}")
if DEBUG_TRITON: print(f"delta_qk = {delta_qk}")
if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701
if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701
# q > k: directly skip all the way until the start of causal block
start_delta_q_gt_k = delta_qk
# q < k: some blocks will have no Masked block, others need to re-calc
Expand All @@ -257,10 +257,10 @@ def _bwd_kernel_dkdv(
start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M
if delta_qk >= 0:
start_delta = delta_qk
if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}")
if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701
else:
start_delta = start_delta_q_lt_k
if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}")
if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701
# align the delta_qk
start_n = pid * BLOCK_N

Expand All @@ -274,7 +274,7 @@ def _bwd_kernel_dkdv(
mask_kv &= mask_k[None, :]
offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk

GROUP_SIZE = HQ // HK
GROUP_SIZE: tl.constexpr = HQ // HK
# K/V tensors not changed for the group
adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
# load K and V: they stay in SRAM throughout the inner loop.
Expand All @@ -293,14 +293,15 @@ def _bwd_kernel_dkdv(
# steps with masked operation to get out of it
residue_m = max(start_n + delta_qk - start_m, 0)
len_m = BLOCK_N + residue_m
if DEBUG_TRITON: print(f"residue_m = {residue_m}")
if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701

# offset input and output tensor by batch and Q/K heads
adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
Q_ptr = Q + adj_q
adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
DO_ptr = DO + adj_do
adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
adj_delta = bid * stride_deltab + hqid * stride_deltah + \
q_start * stride_deltam
M_ptr = M + adj_delta
Delta_ptr = Delta + adj_delta

Expand All @@ -324,7 +325,7 @@ def _bwd_kernel_dkdv(

# if start_m is negative, the current N-tile has no block on the
# diagonal of the causal mask, so nothing here is masked by causality
if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}")
if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701
dk, dv = _bwd_dkdv_inner(
dk, dv, # output tensors
Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors
Expand Down Expand Up @@ -555,7 +556,7 @@ def _bwd_kernel_dq(
K += adj_kv
V += adj_kv
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE = HQ // HK
GROUP_SIZE: tl.constexpr = HQ // HK
for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
# seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front
# for every M-tile
Expand Down Expand Up @@ -703,7 +704,7 @@ def _bwd_kernel_dkdv_noncausal(
mask_kv &= mask_k[None, :]
offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk

GROUP_SIZE = HQ // HK
GROUP_SIZE: tl.constexpr = HQ // HK
# K/V tensors not changed for the group
adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
# load K and V: they stay in SRAM throughout the inner loop.
Expand Down Expand Up @@ -819,7 +820,7 @@ def _bwd_kernel_dq_noncausal(
K += adj_kv
V += adj_kv
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE = HQ // HK
GROUP_SIZE: tl.constexpr = HQ // HK
for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
# offset input and output tensor by batch and Q/K heads
adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
Expand Down
Loading