
Commit

Revert "fixed the qkv pack issue and removed hack"
This reverts commit 58941ed.
jtang10 committed Feb 4, 2025
1 parent bacc596 commit ddd07df
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -3,7 +3,6 @@
import triton.language as tl
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shape_from_layout, get_strides_from_layout, write_dropout_mask, create_dropout_mask

# TODO: move this into utils.py so it's shared among kernels
# NOTE: triton fails to import tl.constexprs so create them here for the file
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP
@@ -596,21 +595,50 @@ def attention_prefill_backward_triton_impl(
    ACTUAL_BLOCK_DMODEL = head_size

    do = do.contiguous()
    # NOTE: we might need to copy the output tensors if they are not contiguous or have other issues
    copy_back = {"dq": False, "dk": False, "dv": False}

    # deal with dq
    if dq is None:
        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
        else:
            dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
    else:
        dq_og = dq
        if (not dq.is_contiguous()):
            dq = dq.contiguous()
            copy_back["dq"] = True

        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
            copy_back["dq"] = True
        else:
            # NOTE: the kernel does inplace accumulation so dq has to be zeros. This avoids the case where we are passed an empty dq that is not all zeros
            dq.zero_()
    stride_dq_all = dq.stride()[0]

    # deal with dk, dv
    if (dk is None) or (dv is None):
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
    else:
        # store og
        dk_og = dk
        dv_og = dv

        if (not dk.is_contiguous()):
            dk = dk.contiguous()
            copy_back["dk"] = True

        if (not dv.is_contiguous()):
            dv = dv.contiguous()
            copy_back["dv"] = True

    if DEBUG:
        print("copy_back:", copy_back)

    # zero out
    dq.zero_()
    dk.zero_()
@@ -759,4 +787,14 @@ def attention_prefill_backward_triton_impl(
print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_bwd")

if copy_back["dq"]:
dq_og.copy_(dq)
dq = dq_og
if copy_back["dk"]:
dk_og.copy_(dk)
dk = dk_og
if copy_back["dv"]:
dv_og.copy_(dv)
dv = dv_og

return dq, dk, dv, delta, None, None
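
For reference, below is a minimal standalone sketch of the copy-back pattern this revert restores: the caller may pass a non-contiguous gradient buffer, so a contiguous working copy is created and zeroed, the kernel accumulates into it in place, and the result is copied back into the caller's original buffer at the end. The run_backward helper and the dq += 1.0 line are illustrative stand-ins, not the actual Triton kernel launch from this file.

import torch

def run_backward(dq: torch.Tensor) -> torch.Tensor:
    # The kernel needs a contiguous, zero-initialized buffer because it
    # accumulates gradients in place.
    copy_back = False
    dq_og = dq
    if not dq.is_contiguous():
        dq = dq.contiguous()
        copy_back = True
    dq.zero_()

    dq += 1.0  # stand-in for the Triton kernel's in-place accumulation

    if copy_back:
        # Write the results back into the caller's original (non-contiguous) buffer.
        dq_og.copy_(dq)
        dq = dq_og
    return dq

# A transposed view is a typical non-contiguous caller-provided buffer.
buf = torch.empty(8, 4).t()   # shape (4, 8), non-contiguous
out = run_backward(buf)
assert out is buf and not out.is_contiguous()
assert torch.all(out == 1.0)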
