Skip to content

Commit

Permalink
Fix: Change deterministic to None by default; use env var if None
Browse files Browse the repository at this point in the history
  • Loading branch information
junrae6454 committed Jul 24, 2024
1 parent b5d9075 commit 7285631
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _flash_attention_forward(
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = False,
    deterministic: Optional[bool] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
Expand Down Expand Up @@ -233,8 +233,9 @@ def _flash_attention_forward(
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if is_flash_attn_greater_or_equal("2.4.1"):
is_deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = is_deterministic
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = deterministic

if softcap is not None:
flash_kwargs["softcap"] = softcap
Expand Down

0 comments on commit 7285631

Please sign in to comment.