Commit 58dcf33

Enable cudnn dropout (apple#913)
1 parent: ae855ed

File tree

1 file changed: +1 addition, -2 deletions


axlearn/common/flash_attention/utils.py

Lines changed: 1 addition & 2 deletions
@@ -224,7 +224,6 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
         or mask.has_value()
         or jnp.float32 in (query.dtype, key.dtype, value.dtype)
         or query.shape[1] != key.shape[1]
-        or dropout_rate != 0.0
     ):
         logging.warning("Flash attention falling back to Triton GPU kernel.")
         logging.warning("explicit_bias after extracting mask: %s", explicit_bias.value())
@@ -253,7 +252,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             bias=explicit_bias.value(),
             softmax_scale=softmax_scale,
             causal=causal.has_value(),
-            dropout_rate=0.0,
+            dropout_rate=dropout_rate,
         )
 
     elif backend == "tpu":
