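Summary of the change: the diff below stops nonzero dropout from forcing the GPU flash-attention path onto the Triton fallback kernel. `dropout_rate != 0.0` is removed from the fallback predicate, and the caller's `dropout_rate` is passed through to the cuDNN kernel instead of a hard-coded `0.0`.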
axlearn/common/flash_attention — 1 file changed: +1 −2 lines changed

@@ -224,7 +224,6 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
         or mask.has_value()
         or jnp.float32 in (query.dtype, key.dtype, value.dtype)
         or query.shape[1] != key.shape[1]
-        or dropout_rate != 0.0
     ):
         logging.warning("Flash attention falling back to Triton GPU kernel.")
         logging.warning("explicit_bias after extracting mask: %s", explicit_bias.value())
@@ -253,7 +252,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             bias=explicit_bias.value(),
             softmax_scale=softmax_scale,
             causal=causal.has_value(),
-            dropout_rate=0.0,
+            dropout_rate=dropout_rate,
         )

    elif backend == "tpu":
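For intuition about what the cuDNN kernel is expected to compute once `dropout_rate` is passed through, here is a minimal pure-JAX reference of scaled dot-product attention with dropout applied to the attention probabilities. This is an illustrative sketch only: the function name, the `[batch, seq, heads, dim]` layout, and the RNG handling are assumptions, not axlearn's or cuDNN's actual API.

import jax
import jax.numpy as jnp

def reference_attention(query, key, value, *, softmax_scale, dropout_rate, dropout_key):
    # query/key/value: [batch, seq, num_heads, head_dim] (assumed layout).
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    probs = jax.nn.softmax(logits, axis=-1)
    if dropout_rate > 0.0:
        # Drop attention probabilities and rescale the survivors, so the
        # expected output is unchanged (standard inverted-dropout semantics).
        keep = jax.random.bernoulli(dropout_key, 1.0 - dropout_rate, probs.shape)
        probs = jnp.where(keep, probs / (1.0 - dropout_rate), 0.0)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, value)

# bf16 inputs with equal query/key lengths satisfy the predicate above and
# would stay on the cuDNN path rather than falling back to Triton.
q = jax.random.normal(jax.random.PRNGKey(0), (2, 128, 8, 64), dtype=jnp.bfloat16)
out = reference_attention(q, q, q, softmax_scale=64**-0.5,
                          dropout_rate=0.1, dropout_key=jax.random.PRNGKey(1))

With the diff applied, enabling dropout during training no longer changes which kernel runs; it only changes the `dropout_rate` argument handed to the cuDNN kernel.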