Add QKVPacked and SPMDFlashAttnVarlenXla support for attention_mask
Seventeen17 committed Sep 23, 2024
1 parent 2f37c34 commit ea6f397
Showing 1 changed file with 35 additions and 40 deletions.
75 changes: 35 additions & 40 deletions torchacc/ops/flash_attn.py
@@ -10,8 +10,7 @@ class FlashAttnVarlenQKVPackedXla(torch.autograd.Function):
     def forward(
         ctx,
         qkv,
-        cu_seqlens,
-        max_seqlen,
+        attention_mask,
         dropout_p,
         softmax_scale,
         causal,
@@ -24,17 +23,20 @@ def forward(
             softmax_scale = qkv.shape[-1]**(-0.5)
 
         assert isinstance(window_size, tuple) and len(window_size) == 2
+        assert attention_mask is not None
 
-        softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
-            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens,
-            alibi_slopes, max_seqlen, max_seqlen, dropout_p, softmax_scale,
+        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+
+        softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
+            q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale,
             False, causal, window_size[0], window_size[1], return_softmax, None)
         out = out.to(qkv.dtype)
 
-        ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
+        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
+                              cu_seqlens_k, rng_state)
         ctx.dropout_p = dropout_p
-        ctx.max_seqlen_q = max_seqlen
-        ctx.max_seqlen_k = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
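
Note on the hunk above: `_flash_attention_forward` now takes the padding `attention_mask` in place of `cu_seqlens`/`max_seqlen` and returns `cu_seqlens_q`/`cu_seqlens_k`, which are saved for the backward pass. For intuition, here is a small standalone sketch of how cumulative sequence lengths relate to such a mask; this is not code from this repository, and the 1 = token / 0 = padding convention is an assumption.

```python
# Illustration only: how a (batch, seqlen) padding mask maps to the
# cumulative sequence lengths used by varlen flash attention.
# Assumes 1 marks a real token and 0 marks padding; not torchacc code.
import torch


def cu_seqlens_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)        # tokens per sequence
    cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)   # running total
    return torch.nn.functional.pad(cu_seqlens, (1, 0))             # prepend leading 0


mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.int32)
print(cu_seqlens_from_mask(mask))  # tensor([0, 3, 5], dtype=torch.int32)
```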
@@ -44,13 +46,14 @@ def forward(
 
     @staticmethod
     def backward(ctx, dout, *args):
-        qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
-            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse, cu_seqlens,
-            cu_seqlens, ctx.alibi_slopes, ctx.max_seqlen_q, ctx.max_seqlen_k,
-            ctx.dropout_p, ctx.softmax_scale, False, ctx.causal,
-            ctx.window_size[0], ctx.window_size[1], ctx.deterministic, None,
-            rng_state)
+            dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            ctx.alibi_slopes, ctx.dropout_p, ctx.softmax_scale, False,
+            ctx.causal, ctx.window_size[0], ctx.window_size[1],
+            ctx.deterministic, None, rng_state)
 
         dqkv = torch.stack([dq, dk, dv], dim=1)
         return dqkv, None, None, None, None, None, None, None, None, None
@@ -63,10 +66,7 @@ def forward(ctx,
                 q,
                 k,
                 v,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
+                attention_mask,
                 dropout_p,
                 softmax_scale,
                 causal,
@@ -98,11 +98,14 @@ def forward(ctx,
             v = xs.enable_manual_sharding(
                 v, partition_spec, mesh=mesh).global_tensor
 
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+
         with torch.no_grad():
-            softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
-                q, k, v, cu_seqlens_q, cu_seqlens_k, alibi_slopes, max_seqlen_q,
-                max_seqlen_k, dropout_p, softmax_scale, False, causal,
-                window_size[0], window_size[1], return_softmax, None)
+            softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
+                q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale,
+                False, causal, window_size[0], window_size[1], return_softmax,
+                None)
 
         if partition_spec is not None:
             out = xs.disable_manual_sharding(
@@ -113,8 +116,6 @@ def forward(ctx,
         ctx.save_for_backward(full_q, full_k, full_v, out, softmax_lse,
                               cu_seqlens_q, cu_seqlens_k, rng_state)
         ctx.dropout_p = dropout_p
-        ctx.max_seqlen_q = max_seqlen_q
-        ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
@@ -141,11 +142,13 @@ def backward(ctx, dout, *args):
             out = xs.enable_manual_sharding(
                 out, partition_spec, mesh=mesh).global_tensor
 
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
             dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
-            ctx.alibi_slopes, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p,
-            ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
-            ctx.window_size[1], ctx.deterministic, None, rng_state)
+            ctx.alibi_slopes, ctx.dropout_p, ctx.softmax_scale, False,
+            ctx.causal, ctx.window_size[0], ctx.window_size[1],
+            ctx.deterministic, None, rng_state)
 
         if partition_spec is not None:
             dq = xs.disable_manual_sharding(
@@ -252,8 +255,7 @@ def backward(ctx, dout, *args):
 
 def flash_attn_varlen_qkvpacked_xla(
         qkv,
-        cu_seqlens,
-        max_seqlen,
+        attention_mask,
         dropout_p=0.0,
         softmax_scale=None,
         causal=False,
@@ -266,8 +268,7 @@ def flash_attn_varlen_qkvpacked_xla(
     ], 'flash attention only supports fp16/bf16'
     return FlashAttnVarlenQKVPackedXla.apply(
         qkv,
-        cu_seqlens,
-        max_seqlen,
+        attention_mask,
         dropout_p,
         softmax_scale,
         causal,
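
For reference, a minimal usage sketch of the updated packed entry point, not part of this commit: the `(batch, 3, seqlen, nheads, headdim)` layout is inferred from the `qkv[:, 0]` indexing above, while the mask dtype, the return value, and the example sequence lengths are assumptions.

```python
# Hypothetical call into the attention_mask-based packed interface.
# Requires an XLA device and a torchacc build that contains this change;
# shapes, mask dtype and the single-tensor return are assumptions.
import torch
import torch_xla.core.xla_model as xm

from torchacc.ops.flash_attn import flash_attn_varlen_qkvpacked_xla

device = xm.xla_device()
batch, seqlen, nheads, headdim = 2, 128, 8, 64

# Packed QKV laid out so that qkv[:, 0] / qkv[:, 1] / qkv[:, 2] are q / k / v.
qkv = torch.randn(batch, 3, seqlen, nheads, headdim,
                  dtype=torch.bfloat16, device=device)

# Padding mask replaces cu_seqlens/max_seqlen: 1 = real token, 0 = padding.
attention_mask = torch.ones(batch, seqlen, dtype=torch.int32, device=device)
attention_mask[1, 96:] = 0  # pretend the second sequence has 96 valid tokens

out = flash_attn_varlen_qkvpacked_xla(
    qkv,
    attention_mask,
    dropout_p=0.0,
    causal=True,
)
```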
@@ -282,10 +283,7 @@ def spmd_flash_attn_varlen_xla(
         q,
         k,
         v,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
+        attention_mask,
         dropout_p=0.0,
         softmax_scale=None,
         causal=False,
@@ -302,10 +300,7 @@ def spmd_flash_attn_varlen_xla(
         q,
         k,
         v,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
+        attention_mask,
         dropout_p,
         softmax_scale,
         causal,
@@ -362,7 +357,7 @@ def flash_attn_xla(
         deterministic=False,
         return_attn_probs=False,
 ):
-    assert q.dtype in [torch.bfloat16,
+    assert q.dtype in [torch.bfloaft16,
                        torch.float16], 'flash attention only supports fp16/bf16'
     return FlashAttnXla.apply(
         q,
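
Call sites that still hold `cu_seqlens`/`max_seqlen` need to supply an `attention_mask` to the varlen interfaces above. One possible migration helper, sketched under the same 1 = token / 0 = padding assumption; the helper name and shapes are illustrative, nothing in this commit provides it.

```python
# Hypothetical helper for migrating old cu_seqlens-based call sites to the
# new attention_mask argument; not part of torchacc.
import torch


def mask_from_cu_seqlens(cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]                       # (batch,)
    positions = torch.arange(max_seqlen, device=cu_seqlens.device)   # (max_seqlen,)
    return (positions.unsqueeze(0) < seqlens.unsqueeze(1)).to(torch.int32)


cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
print(mask_from_cu_seqlens(cu_seqlens, max_seqlen=4))
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]], dtype=torch.int32)
```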
