
Commit b7deeba

Guards bias grad return
Ensures _flash_dmattn_backward returns a bias gradient only when a bias is present and returns None in that position otherwise, so the return value remains a four-element tuple for bias-less calls.
1 parent b9e3eaa commit b7deeba

1 file changed: +1 -1 lines changed

flash_dmattn/flash_dmattn_triton.py

Lines changed: 1 addition & 1 deletion
@@ -1150,7 +1150,7 @@ def _flash_dmattn_backward(
         if bias.shape[0] == 1:
             dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True)
         dbias.copy_(dbias_expanded)
-    return dq, dk, dv, dbias
+    return dq, dk, dv, dbias if has_bias else None
 
 
 def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
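
The changed return statement depends on Python's operator precedence: a conditional expression binds more tightly than the commas that build a tuple, so `dbias if has_bias else None` applies only to the last element and the function still returns four values. The following is a minimal sketch of that behavior, not the library's code; `backward_sketch` and the string stand-ins for the gradient tensors are hypothetical names used only for illustration.

```python
from typing import Optional, Tuple


def backward_sketch(has_bias: bool) -> Tuple[str, str, str, Optional[str]]:
    """Hypothetical stand-in that mirrors only the return statement of the diff."""
    # Strings stand in for the gradient tensors (assumption for illustration).
    dq, dk, dv, dbias = "dq", "dk", "dv", "dbias"
    # Parsed as (dq, dk, dv, (dbias if has_bias else None)):
    # the conditional applies to the last element, not to the whole tuple.
    return dq, dk, dv, dbias if has_bias else None


with_bias = backward_sketch(True)
without_bias = backward_sketch(False)

# Both calls yield a 4-tuple; only the last slot differs.
print(len(with_bias), with_bias[3])
print(len(without_bias), without_bias[3])
```

Callers can therefore unpack four values in both cases and check the last one against `None` before using it.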
