Commit 0eae652

add contiguous and dtype for fa_varlen
1 parent ddb7530 commit 0eae652

2 files changed: +22 -11 lines changed

torchacc/ops/flash_attn.py

Lines changed: 10 additions & 1 deletion
@@ -171,10 +171,14 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
             softmax_scale = q.shape[-1]**(-0.5)
         assert isinstance(window_size, tuple) and len(window_size) == 2

+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+
         softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
             q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale, False, causal,
             window_size[0], window_size[1], return_softmax, None)
         out = out.to(q.dtype)
+
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
                               cu_seqlens_k, rng_state)
         ctx.dropout_p = dropout_p
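
The stride(-1) check above only forces a copy when the innermost dimension of a tensor is not densely packed, which is presumably what the XLA flash-attention custom call expects. A minimal standalone sketch of the helper in plain PyTorch (illustrative only, not part of the diff):

    import torch

    # Same helper the commit adds inside forward()/backward(): copy only when
    # the innermost dimension is strided.
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x

    q = torch.randn(2, 8, 16, 64)                    # stride(-1) == 1, returned as-is
    k = torch.randn(2, 8, 64, 16).transpose(-1, -2)  # stride(-1) != 1 after the transpose

    q2, k2 = maybe_contiguous(q), maybe_contiguous(k)
    assert q2 is q               # already contiguous: no copy made
    assert k2.stride(-1) == 1    # transposed view was materialized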
@@ -189,6 +193,8 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors

+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
             dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
             ctx.alibi_slopes, ctx.dropout_p,
@@ -198,6 +204,7 @@ def backward(ctx, dout, *args):
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
         dv = dv[..., :dout.shape[-1]]
+
         return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None

@@ -217,7 +224,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
             dropout_p, softmax_scale, False, causal, window_size[0],
             window_size[1], return_softmax, None)
         out = out.to(q.dtype)
-
+
         ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
         ctx.softmax_scale = softmax_scale
@@ -327,6 +334,8 @@ def flash_attn_varlen_xla(
 ):
     assert q.dtype in [torch.bfloat16,
                        torch.float16], 'flash attention only supports fp16/bf16'
+    if attention_mask.dtype != torch.int32:
+        attention_mask = attention_mask.to(torch.int32)
     return FlashAttnVarlenXla.apply(
         q,
         k,
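
For context, Hugging Face tokenizers typically emit the padding mask as torch.int64, so the dtype guard above would normally trigger. A short illustrative sketch (not part of the commit), noting that tensor.to() returns a new tensor rather than converting in place:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])        # defaults to torch.int64
    if attention_mask.dtype != torch.int32:
        attention_mask = attention_mask.to(torch.int32)  # .to() is out-of-place
    assert attention_mask.dtype == torch.int32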

torchacc/utils/patch.py

Lines changed: 12 additions & 10 deletions
@@ -59,16 +59,16 @@ def patch_fa():
     `torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`,
     and dynamically determine which one to use at runtime based on the input device.
     '''
-    try:
-        import flash_attn
-        if hasattr(flash_attn.flash_attn_func, '__orig'):
-            return
-        flash_attn.flash_attn_func = _choose_functions(
-            flash_attn.flash_attn_func, ops.flash_attn_xla)
-        # flash_attn.flash_attn_varlen_func = _choose_functions(
-        #     flash_attn.flash_attn_varlen_func, ops.flash_attn_varlen_xla)
-    except ImportError:
-        logger.warn(f"Patch flash_attn failed.")
+    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+    def _flash_attention_forward(
+            self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        if attention_mask is not None:
+            return ops.flash_attn_varlen_xla(query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), attention_mask=attention_mask.contiguous(), dropout_p=dropout, softmax_scale=softmax_scale)
+        else:
+            return ops.flash_attn_xla(query_states, key_states, value_states, dropout_p=dropout, softmax_scale=softmax_scale)
+
+    LlamaFlashAttention2._flash_attention_forward = _flash_attention_forward


 def patch_llama(use_flash_attn):
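
With this change, patch_fa() no longer wraps the flash_attn package; it monkey-patches LlamaFlashAttention2 directly, routing padded batches (attention_mask present) to ops.flash_attn_varlen_xla and dense batches to ops.flash_attn_xla. A hedged usage sketch, with the import path inferred from the file layout rather than shown in the diff:

    # Assumed module path: the patch lives in torchacc/utils/patch.py.
    from torchacc.utils.patch import patch_fa

    patch_fa()
    # From here on, LlamaFlashAttention2._flash_attention_forward dispatches to
    # ops.flash_attn_varlen_xla when an attention_mask is supplied, otherwise
    # to ops.flash_attn_xla.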
@@ -89,6 +89,8 @@ def update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        if attention_mask is not None:
+            return attention_mask
         return None

     LlamaModel._update_causal_mask = update_causal_mask
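
The patched _update_causal_mask now passes a user-supplied padding mask through instead of always returning None, which is what lets the patched _flash_attention_forward above pick the varlen path. A simplified sketch of the combined dispatch (names are illustrative, not the transformers API):

    import torch

    def pick_attention_path(attention_mask):
        # Mirrors the two patches: a forwarded padding mask selects the
        # variable-length kernel, a None mask selects the fixed-length one.
        if attention_mask is not None:
            return "flash_attn_varlen_xla"
        return "flash_attn_xla"

    print(pick_attention_path(torch.tensor([[1, 1, 0]])))  # flash_attn_varlen_xla
    print(pick_attention_path(None))                       # flash_attn_xla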
