add contiguous and dtype for fa_varlen
Seventeen17 committed Sep 10, 2024
1 parent ddb7530 commit 0eae652
Showing 2 changed files with 22 additions and 11 deletions.
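
The commit applies two small normalizations before handing tensors to the XLA flash-attention kernels: inputs are made contiguous in their last dimension only when needed, and dtypes are aligned (the kernel output is cast back to q.dtype, the attention mask to int32). A minimal standalone sketch of that pattern follows; the tensor shapes are illustrative assumptions, not taken from this repository.

import torch

# Hypothetical [batch, seq, heads, head_dim] tensors; the real call sites pass the
# projections produced by the attention layer.
q = torch.randn(2, 1024, 16, 64, dtype=torch.bfloat16)
attention_mask = torch.ones(2, 1024, dtype=torch.int64)

# Copy only when the last dimension is not already contiguous.
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q = maybe_contiguous(q)

# The varlen path expects an int32 mask, so cast only when needed.
if attention_mask.dtype != torch.int32:
    attention_mask = attention_mask.to(torch.int32)
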
11 changes: 10 additions & 1 deletion torchacc/ops/flash_attn.py
@@ -171,10 +171,14 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
softmax_scale = q.shape[-1]**(-0.5)
assert isinstance(window_size, tuple) and len(window_size) == 2

maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale, False, causal,
window_size[0], window_size[1], return_softmax, None)
out = out.to(q.dtype)

ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
@@ -189,6 +193,8 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors

maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
ctx.alibi_slopes, ctx.dropout_p,
@@ -198,6 +204,7 @@ def backward(ctx, dout, *args):
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
dv = dv[..., :dout.shape[-1]]

return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None


@@ -217,7 +224,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
dropout_p, softmax_scale, False, causal, window_size[0],
window_size[1], return_softmax, None)
out = out.to(q.dtype)

ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
@@ -327,6 +334,8 @@ def flash_attn_varlen_xla(
):
assert q.dtype in [torch.bfloat16,
torch.float16], 'flash attention only supports fp16/bf16'
if attention_mask.dtype != torch.int32:
    attention_mask = attention_mask.to(torch.int32)
return FlashAttnVarlenXla.apply(
q,
k,
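For reference, after this change flash_attn_varlen_xla asserts fp16/bf16 inputs and normalizes the mask dtype itself. A hedged usage sketch follows; the shapes, the import path of ops, and running on an actual XLA device are assumptions, while the argument layout mirrors the call site added in patch.py below.

import torch
from torchacc import ops  # assumed import; patch.py below refers to these as ops.flash_attn_*

# Hypothetical bf16 inputs of shape [batch, seq, heads, head_dim]; a real run needs an XLA device.
q = torch.randn(2, 512, 16, 64, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
attention_mask = torch.ones(2, 512, dtype=torch.int32)  # padding mask; other int dtypes are cast internally

attn_out = ops.flash_attn_varlen_xla(
    q.contiguous(), k.contiguous(), v.contiguous(),
    attention_mask=attention_mask,
    dropout_p=0.0,
    softmax_scale=q.shape[-1] ** -0.5)  # same value forward() derives when softmax_scale is None
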
22 changes: 12 additions & 10 deletions torchacc/utils/patch.py
@@ -59,16 +59,16 @@ def patch_fa():
`torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`,
and dynamically determine which one to use at runtime based on the input device.
'''
try:
import flash_attn
if hasattr(flash_attn.flash_attn_func, '__orig'):
return
flash_attn.flash_attn_func = _choose_functions(
flash_attn.flash_attn_func, ops.flash_attn_xla)
# flash_attn.flash_attn_varlen_func = _choose_functions(
# flash_attn.flash_attn_varlen_func, ops.flash_attn_varlen_xla)
except ImportError:
logger.warn(f"Patch flash_attn failed.")
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
if attention_mask is not None:
    return ops.flash_attn_varlen_xla(
        query_states.contiguous(), key_states.contiguous(), value_states.contiguous(),
        attention_mask=attention_mask.contiguous(), dropout_p=dropout, softmax_scale=softmax_scale)
else:
    return ops.flash_attn_xla(
        query_states, key_states, value_states, dropout_p=dropout, softmax_scale=softmax_scale)

LlamaFlashAttention2._flash_attention_forward = _flash_attention_forward


def patch_llama(use_flash_attn):
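
The patch_fa docstring above describes choosing between the native CUDA kernels and the XLA ones at runtime based on the input device; the removed code did this by wrapping flash_attn.flash_attn_func with _choose_functions, whose implementation is not part of this diff. A purely illustrative sketch of what such a chooser could look like:

# Illustrative only: the real _choose_functions lives elsewhere in torchacc and is not shown here.
def _choose_functions(cuda_fn, xla_fn):
    def wrapper(*args, **kwargs):
        # Dispatch on the device of the first tensor argument (q).
        if args and args[0].device.type == 'xla':
            return xla_fn(*args, **kwargs)
        return cuda_fn(*args, **kwargs)
    wrapper.__orig = cuda_fn  # marker that the hasattr(..., '__orig') guard above checks
    return wrapper
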
@@ -89,6 +89,8 @@ def update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
if attention_mask is not None:
    return attention_mask
return None

LlamaModel._update_causal_mask = update_causal_mask
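
Finally, a hedged sketch of how these patches might be activated. The import path follows the file location; the exact call sites, and whether patch_llama gates its behavior on the argument, are not shown in this diff.

# Hypothetical call sites; how TorchAcc wires these into model setup is not shown here.
from torchacc.utils.patch import patch_fa, patch_llama

patch_fa()                        # reroutes LlamaFlashAttention2._flash_attention_forward (above)
patch_llama(use_flash_attn=True)  # presumably installs the simplified _update_causal_mask (above)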
