Performant backward Triton implementation with separated dkdv and dq kernels #122

Merged
merged 34 commits into from
Feb 4, 2025
Changes from 1 commit
Commits (34)
f8c1ee5
added the split file
jtang10 Dec 17, 2024
c1b5fae
overhauled split file, need to add new kernels
jtang10 Dec 17, 2024
a3f622f
copied triton fa over for reference
jtang10 Dec 18, 2024
90d639e
added comments
jtang10 Dec 18, 2024
0560963
preprocess and dkdv done
jtang10 Dec 27, 2024
9328e5a
fixed dkdv, added dq
jtang10 Dec 27, 2024
682eb1f
fixed assumption on q, kv length different, run but incorrect
jtang10 Jan 6, 2025
d0afca7
added standalone test for split bwd kernel
jtang10 Jan 7, 2025
1c542d2
minor change on the ptr arith
jtang10 Jan 7, 2025
8b1629f
separated the dkdv and dq kernels
jtang10 Jan 8, 2025
2e4c812
GQA works now, onto seqlen q != k
jtang10 Jan 9, 2025
be65b39
dk,dq working, dv still failing
jtang10 Jan 11, 2025
5e5ad91
fixed the masking and num_step calc, now q==k works
jtang10 Jan 21, 2025
2e67d95
added debug print with interpreter, might not work entirely w/o next …
jtang10 Jan 28, 2025
149dd4b
fixed all issues with q != k
jtang10 Jan 28, 2025
c16a4f7
fixed varlen issue
jtang10 Jan 28, 2025
58cc0f2
fixup on debug print
jtang10 Jan 28, 2025
88054a7
fixed dropout, esp w/ varlen
jtang10 Jan 29, 2025
059f665
added USE_EXP2 toggle
jtang10 Jan 29, 2025
ac18466
added noncausal kernel
jtang10 Jan 29, 2025
9fa9688
updated internal test for noncausal and use_exp2
jtang10 Jan 29, 2025
10e5468
formatting
jtang10 Jan 29, 2025
91b30ac
fixed dropout from seed bug
jtang10 Jan 30, 2025
00d2c77
added envvar USE_SPLIT to toggle btw bwd kernels
jtang10 Jan 30, 2025
58941ed
fixed the qkv pack issue and removed hack
jtang10 Jan 31, 2025
7fadfe3
added the split kernel into interface_fa.py
jtang10 Jan 31, 2025
c97f298
change USE_SPLIT to USE_SINGLE_BWD_KERNEL to make split default
jtang10 Jan 31, 2025
c6f5607
removed redundant file
jtang10 Jan 31, 2025
383d8c7
fixed missing import in test
jtang10 Feb 3, 2025
8e99137
fixed import in interface_fa.py
jtang10 Feb 3, 2025
8436dc7
revert changes in flash_attn_interface.py
jtang10 Feb 3, 2025
ada4bb8
updated strides to adapt to various tensor init shape
jtang10 Feb 3, 2025
bacc596
fixed issue that dqkv not zero'd
jtang10 Feb 4, 2025
0055b35
disabled the AMD local test
jtang10 Feb 4, 2025
fixed the qkv pack issue and removed hack
jtang10 committed Feb 3, 2025
commit 58941ed269b9b94c3dd8b6a87cc54f6a32404d9c
65 changes: 32 additions & 33 deletions flash_attn/flash_attn_interface.py
@@ -470,12 +470,10 @@ def forward(
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
-        head_size_og = q.size(3)
+        head_size_og = qkv.shape[-1]
         if head_size_og % 8 != 0:
-            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
-            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
-            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+            qkv = torch.nn.functional.pad(qkv, [0, 8 - head_size_og % 8])
+        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
         out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
             q,
             k,
@@ -489,7 +487,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
+        ctx.save_for_backward(qkv, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
@@ -502,23 +500,23 @@ def backward(ctx, dout, *args):

     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
-        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
+        qkv, out, softmax_lse, rng_state = ctx.saved_tensors
+        b, s, p, h, d = qkv.shape
+        dqkv = torch.empty((p, b, s, h, d), dtype=qkv.dtype, device=qkv.device)
         head_size_og = dout.size(3)
         dout_padded = dout
         if head_size_og % 8 != 0:
             dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
         _wrapped_flash_attn_backward(
             dout_padded,
-            q,
-            k,
-            v,
+            qkv[:, :, 0],
+            qkv[:, :, 1],
+            qkv[:, :, 2],
             out,
             softmax_lse,
-            dqkv[:, :, 0],
-            dqkv[:, :, 1],
-            dqkv[:, :, 2],
+            dqkv[0],
+            dqkv[1],
+            dqkv[2],
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
@@ -529,7 +527,8 @@ def backward(ctx, dout, *args):
             ctx.deterministic,
             rng_state=rng_state,
         )
-        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dqkv = dqkv[..., :head_size_og]  # We could have padded the head dimension
+        dqkv = torch.permute(dqkv, (1, 2, 0, 3, 4)).contiguous()
         return dqkv, None, None, None, None, None, None, None, None

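Note: the backward pass now allocates dqkv with the packing dimension first, (3, b, s, h, d), which makes each of dqkv[0], dqkv[1], dqkv[2] a contiguous buffer to hand to the kernel, and then permutes the result back to the (b, s, 3, h, d) layout the packed qkv gradient must have. A minimal sketch of that round trip (not part of the diff; sizes are hypothetical):

import torch

b, s, h, d = 2, 128, 4, 64                       # hypothetical sizes
dqkv = torch.empty((3, b, s, h, d))
assert dqkv[0].is_contiguous()                   # per-tensor slice of the new layout is contiguous
assert not torch.empty((b, s, 3, h, d))[:, :, 0].is_contiguous()  # the old layout's slice is not
# ... the backward kernel writes dq, dk, dv into dqkv[0], dqkv[1], dqkv[2] ...
dqkv = torch.permute(dqkv, (1, 2, 0, 3, 4)).contiguous()
assert dqkv.shape == (b, s, 3, h, d)
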
@@ -551,12 +550,10 @@ def forward(
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
-        head_size_og = q.size(2)
+        head_size_og = qkv.shape[-1]
         if head_size_og % 8 != 0:
-            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
-            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
-            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+            qkv = torch.nn.functional.pad(qkv, [0, 8 - head_size_og % 8])
+        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
         out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
             q,
             k,
@@ -575,7 +572,7 @@ def forward(
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
         )
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
+        ctx.save_for_backward(qkv, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
@@ -589,23 +586,24 @@ def backward(ctx, dout, *args):

     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
-        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-        head_size_og = dout.size(2)
+        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
+        print('qkv.shape:', qkv.shape)
+        s, p, h, d = qkv.shape
+        dqkv = torch.empty((p, s, h, d), dtype=qkv.dtype, device=qkv.device)
+        head_size_og = dout.size(-1)
         dout_padded = dout
         if head_size_og % 8 != 0:
             dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
         _wrapped_flash_attn_varlen_backward(
             dout_padded,
-            q,
-            k,
-            v,
+            qkv[:, 0],
+            qkv[:, 1],
+            qkv[:, 2],
             out,
             softmax_lse,
-            dqkv[:, 0],
-            dqkv[:, 1],
-            dqkv[:, 2],
+            dqkv[0],
+            dqkv[1],
+            dqkv[2],
             cu_seqlens,
             cu_seqlens,
             ctx.max_seqlen,
@@ -620,7 +618,8 @@ def backward(ctx, dout, *args):
             ctx.deterministic,
             rng_state=rng_state,
         )
-        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dqkv = dqkv[..., :head_size_og]  # We could have padded the head dimension
+        dqkv = torch.permute(dqkv, (1, 0, 2, 3)).contiguous()
         return dqkv, None, None, None, None, None, None, None, None, None, None

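The varlen class applies the same idea to the 4-D packed layout (total_tokens, 3, nheads, headdim): dqkv is allocated as (3, total_tokens, nheads, headdim) and permuted back with (1, 0, 2, 3). A corresponding sketch, again with hypothetical sizes:

import torch

total_tokens, h, d = 512, 4, 64                  # hypothetical sizes
dqkv = torch.empty((3, total_tokens, h, d))
assert dqkv[0].is_contiguous()
dqkv = torch.permute(dqkv, (1, 0, 2, 3)).contiguous()
assert dqkv.shape == (total_tokens, 3, h, d)
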
39 changes: 0 additions & 39 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -596,50 +596,21 @@ def attention_prefill_backward_triton_impl(
     ACTUAL_BLOCK_DMODEL = head_size

     do = do.contiguous()
-    # NOTE: we might need to copy the output tensor if they are not continuous or have other issues
-    copy_back = {"dq": False, "dk": False, "dv": False}

     # deal with dq
     if dq is None:
         if sequence_parallel:
             dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
         else:
             dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
-    else:
-        dq_og = dq
-        if (not dq.is_contiguous()):
-            dq = dq.contiguous()
-            copy_back["dq"] = True
-
-        if sequence_parallel:
-            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
-            copy_back["dq"] = True
-        else:
-            # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros
-            dq.zero_()
     stride_dq_all = dq.stride()[0]

     # deal with dk, dv
     if (dk is None) or (dv is None):
         dk = torch.zeros_like(k)
         dv = torch.zeros_like(v)
-    else:
-        # store og
-        dk_og = dk
-        dv_og = dv
-
-
-        if (not dk.is_contiguous()):
-            dk = dk.contiguous()
-            copy_back["dk"] = True
-
-        if (not dv.is_contiguous()):
-            dv = dv.contiguous()
-            copy_back["dv"] = True

-    if DEBUG:
-        print("copy_back:", copy_back)

     # zero out
     dq.zero_()
     dk.zero_()
@@ -788,14 +759,4 @@ def attention_prefill_backward_triton_impl(
         print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum() / dropout_mask.numel()).item())
         write_dropout_mask(dropout_mask, "dropout_mask_bwd")

-    if copy_back["dq"]:
-        dq_og.copy_(dq)
-        dq = dq_og
-    if copy_back["dk"]:
-        dk_og.copy_(dk)
-        dk = dk_og
-    if copy_back["dv"]:
-        dv_og.copy_(dv)
-        dv = dv_og
-
     return dq, dk, dv, delta, None, None
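With the copy_back hack removed, the implementation relies on the unconditional zero-out block kept above (dq.zero_(), dk.zero_(), ...): as the deleted NOTE comment said, the kernel accumulates into these buffers in place, so they must start at zero even when the caller passes in preallocated, possibly uninitialized tensors. A stand-in sketch of that requirement in plain PyTorch (not the Triton kernel; names are illustrative):

import torch

def accumulate_partial(grad_buf, partial):
    # stand-in for a kernel program that does grad_buf += partial per block
    grad_buf += partial

dq = torch.empty(4)   # may hold arbitrary stale values
dq.zero_()            # required before in-place accumulation
for partial in (torch.ones(4), torch.ones(4)):
    accumulate_partial(dq, partial)
assert torch.equal(dq, torch.full((4,), 2.0))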