Commit b9e3eaa

Rename _flash_attn_forward/_flash_attn_backward to _flash_dmattn_forward/_flash_dmattn_backward and update call sites
1 parent 53c34fa commit b9e3eaa

File tree: 1 file changed (+4 −4 lines)


flash_dmattn/flash_dmattn_triton.py

Lines changed: 4 additions & 4 deletions
@@ -888,7 +888,7 @@ def _bwd_kernel(
         )


-def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
+def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False):
     # shape constraints
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, nheads_k, _ = k.shape
@@ -980,7 +980,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False
     return o, lse, softmax_scale  # softmax_scale could have been updated


-def _flash_attn_backward(
+def _flash_dmattn_backward(
     do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False
 ):
     # Make sure that the last dimension is contiguous
@@ -1195,7 +1195,7 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa
         else:
             attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]])

-        o, lse, ctx.softmax_scale = _flash_attn_forward(
+        o, lse, ctx.softmax_scale = _flash_dmattn_forward(
             query,
             key,
             value,
@@ -1218,7 +1218,7 @@ def backward(ctx, do):
         if head_size_og % 8 != 0:
             do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8])

-        dq, dk, dv, dbias = _flash_attn_backward(
+        dq, dk, dv, dbias = _flash_dmattn_backward(
             do_padded,
             query,
             key,

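Because _flash_dmattn_forward and _flash_dmattn_backward are private, module-level helpers, the rename only needs the four in-file call-site updates shown above. As a hedged sketch (the import path is inferred from the changed file, and whether any out-of-tree code reaches into these private names is an assumption), external code that imported the old helpers directly would simply follow the rename:

# Hedged sketch: only relevant if code outside this module imported the
# private helpers directly; the diff itself does not show any such callers.
from flash_dmattn.flash_dmattn_triton import (
    _flash_dmattn_forward,   # formerly _flash_attn_forward
    _flash_dmattn_backward,  # formerly _flash_attn_backward
)

Code that goes through the public autograd wrapper is unaffected, since the commit only renames the internal helpers and updates their call sites inside forward and backward.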