Skip to content

Commit 071ab90

Browse files
committed
Disables tensor sanitization in attention ops
Removes NaN/Inf zeroing in forward and backward paths to avoid masking numerical issues and to reduce overhead. Preserves raw outputs/gradients for easier debugging and correctness checks; callers can sanitize if required.
1 parent 059776d commit 071ab90

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

flash_dmattn/flash_dmattn_interface.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ def _flash_dmattn_forward(
9595
softcap,
9696
return_softmax,
9797
)
98-
_sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0)
98+
# _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0)
9999
return out, softmax_lse, S_dmask
100100

101101

@@ -163,7 +163,7 @@ def _flash_dmattn_varlen_forward(
163163
softcap,
164164
return_softmax,
165165
)
166-
_sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0)
166+
# _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0)
167167
return out, softmax_lse, S_dmask
168168

169169

@@ -247,7 +247,7 @@ def _flash_dmattn_backward(
247247
softcap,
248248
deterministic,
249249
)
250-
_sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0)
250+
# _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0)
251251
return softmax_d
252252

253253

@@ -335,7 +335,7 @@ def _flash_dmattn_varlen_backward(
335335
softcap,
336336
deterministic,
337337
)
338-
_sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0)
338+
# _sanitize_tensors(dq, dk, dv, nan=0.0, posinf=0.0, neginf=0.0)
339339
return softmax_d
340340

341341

0 commit comments

Comments (0)