diff --git a/torchacc/ops/flash_attn.py b/torchacc/ops/flash_attn.py
index 50bbdbe..6524260 100644
--- a/torchacc/ops/flash_attn.py
+++ b/torchacc/ops/flash_attn.py
@@ -171,10 +171,14 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
             softmax_scale = q.shape[-1]**(-0.5)
         assert isinstance(window_size, tuple) and len(window_size) == 2

+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+
         softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
             q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale,
             False, causal, window_size[0], window_size[1], return_softmax, None)
         out = out.to(q.dtype)
+
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
                               cu_seqlens_k, rng_state)
         ctx.dropout_p = dropout_p
@@ -189,6 +193,8 @@ def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors

+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
             dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
             ctx.alibi_slopes, ctx.dropout_p,
@@ -198,6 +204,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
+
         return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
@@ -217,7 +224,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
             dropout_p, softmax_scale, False, causal, window_size[0],
             window_size[1], return_softmax, None)
         out = out.to(q.dtype)
-
+
         ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
@@ -327,6 +334,8 @@ def flash_attn_varlen_xla(
 ):
     assert q.dtype in [torch.bfloat16, torch.float16], 'flash attention only supports fp16/bf16'

+    if attention_mask.dtype != torch.int32:
+        attention_mask = attention_mask.to(torch.int32)
     return FlashAttnVarlenXla.apply(
         q,
         k,
diff --git a/torchacc/utils/patch.py b/torchacc/utils/patch.py
index c2d2a4e..cd8a7ce 100644
--- a/torchacc/utils/patch.py
+++ b/torchacc/utils/patch.py
@@ -59,16 +59,16 @@ def patch_fa():
     `torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`, and
     dynamically determine which one to use at runtime based on the input device.
     '''
-    try:
-        import flash_attn
-        if hasattr(flash_attn.flash_attn_func, '__orig'):
-            return
-        flash_attn.flash_attn_func = _choose_functions(
-            flash_attn.flash_attn_func, ops.flash_attn_xla)
-        # flash_attn.flash_attn_varlen_func = _choose_functions(
-        #     flash_attn.flash_attn_varlen_func, ops.flash_attn_varlen_xla)
-    except ImportError:
-        logger.warn(f"Patch flash_attn failed.")
+    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        if attention_mask is not None:
+            return ops.flash_attn_varlen_xla(query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), attention_mask=attention_mask.contiguous(), dropout_p=dropout, softmax_scale=softmax_scale)
+        else:
+            return ops.flash_attn_xla(query_states, key_states, value_states, dropout_p=dropout, softmax_scale=softmax_scale)
+
+    LlamaFlashAttention2._flash_attention_forward = _flash_attention_forward


 def patch_llama(use_flash_attn):
@@ -89,6 +89,8 @@ def update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        if attention_mask is not None:
+            return attention_mask
         return None

     LlamaModel._update_causal_mask = update_causal_mask
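
Note: below is a minimal usage sketch of the two kernels the patched `_flash_attention_forward` dispatches between. The tensor layout (batch, seqlen, num_heads, head_dim), the device setup via `xm.xla_device()`, and the mask shape are illustrative assumptions, not part of this diff; only the call patterns mirror the patched code.

# Illustrative sketch only (assumed shapes/device, not part of the patch):
# dense path when there is no padding mask, varlen path when a per-token mask exists.
import torch
import torch_xla.core.xla_model as xm
from torchacc import ops

device = xm.xla_device()
batch, seqlen, num_heads, head_dim = 2, 128, 8, 64  # assumed toy sizes
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# No mask: the patched forward takes the dense flash_attn_xla path.
out = ops.flash_attn_xla(q, k, v, dropout_p=0.0, softmax_scale=None)

# With a padding mask (1 = keep, 0 = pad): the varlen path is used; the wrapper
# casts the mask to int32 if needed (see flash_attn_varlen_xla above).
attention_mask = torch.ones(batch, seqlen, dtype=torch.int32, device=device)
out = ops.flash_attn_varlen_xla(q, k, v, attention_mask=attention_mask,
                                dropout_p=0.0, softmax_scale=None)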