@@ -888,7 +888,7 @@ def _bwd_kernel(
         )
 
 
-def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
+def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, nheads_k, _ = k.shape
@@ -980,7 +980,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False
     return o, lse, softmax_scale  # softmax_scale could have been updated
 
 
-def _flash_attn_backward(
+def _flash_dmattn_backward(
     do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
@@ -1195,7 +1195,7 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
         else:
             attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]])
 
-        o, lse, ctx.softmax_scale = _flash_attn_forward(
+        o, lse, ctx.softmax_scale = _flash_dmattn_forward(
             query,
             key,
             value,
@@ -1218,7 +1218,7 @@ def backward(ctx, do):
         if head_size_og % 8 != 0:
             do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8])
 
-        dq, dk, dv, dbias = _flash_attn_backward(
+        dq, dk, dv, dbias = _flash_dmattn_backward(
             do_padded,
             query,
             key,
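
For context on the call pattern this commit touches: the sketch below is an assumption reconstructed only from the signatures visible in this diff, not the repo's implementation. The _flash_dmattn_forward stub uses plain dense attention math purely for illustration (the real helper launches Triton kernels), and the FlashDMAttnFunc wrapper name and its backward delegation to _flash_dmattn_backward are likewise hypothetical.

import torch

# Hypothetical dense stand-in for the renamed Triton-backed helper.
# Signature follows the diff above; shapes: q (batch, seqlen_q, nheads, d),
# k/v (batch, seqlen_k, nheads, d), mask/bias broadcastable to
# (batch, nheads, seqlen_q, seqlen_k).
def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        scores = scores.masked_fill(~mask.bool(), float("-inf"))
    if is_causal:
        # Simple top-left causal alignment; only meant as an illustration.
        q_idx = torch.arange(q.shape[1], device=q.device)[:, None]
        k_idx = torch.arange(k.shape[1], device=q.device)[None, :]
        scores = scores.masked_fill(k_idx > q_idx, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1)  # saved for the backward pass
    o = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(scores, dim=-1), v)
    return o, lse, softmax_scale


class FlashDMAttnFunc(torch.autograd.Function):
    # Sketch of an autograd wrapper matching the forward() call site in the
    # diff; a real backward() would delegate to _flash_dmattn_backward the
    # same way and return dq, dk, dv, dbias.
    @staticmethod
    def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False):
        o, lse, softmax_scale = _flash_dmattn_forward(
            query, key, value, attn_mask, attn_bias, is_causal=is_causal
        )
        ctx.save_for_backward(query, key, value, o, lse)
        ctx.softmax_scale = softmax_scale
        ctx.is_causal = is_causal
        return o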