@@ -171,10 +171,14 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
            softmax_scale = q.shape[-1]**(-0.5)
        assert isinstance(window_size, tuple) and len(window_size) == 2
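+        # Make the innermost dimension contiguous (stride(-1) == 1) before handing q/k/v to the XLA kernel.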
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+
        softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
            q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale, False, causal,
            window_size[0], window_size[1], return_softmax, None)
        out = out.to(q.dtype)
+
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
                              cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
@@ -189,6 +193,8 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors

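+        # As in forward: make dout and the saved tensors contiguous in their last dimension.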
+        maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+        dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
        dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
            dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
            ctx.alibi_slopes, ctx.dropout_p,
@@ -198,6 +204,7 @@ def backward(ctx, dout, *args):
        dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., :dout.shape[-1]]
        dv = dv[..., :dout.shape[-1]]
+
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
@@ -217,7 +224,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
            dropout_p, softmax_scale, False, causal, window_size[0],
            window_size[1], return_softmax, None)
        out = out.to(q.dtype)
-
+
        ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
@@ -327,6 +334,8 @@ def flash_attn_varlen_xla(
):
    assert q.dtype in [torch.bfloat16,
                       torch.float16], 'flash attention only supports fp16/bf16'
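+    # Normalize the attention mask dtype to int32 before dispatching to the XLA varlen kernel.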
+    if attention_mask.dtype != torch.int32:
+        attention_mask = attention_mask.to(torch.int32)
    return FlashAttnVarlenXla.apply(
        q,
        k,