Commit 9f18cc6

Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)
fix sdpa dispatch
1 parent 7ea21f1 commit 9f18cc6
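
For context on what this touches: the dispatch runs when a model is loaded and the user either relies on automatic attention selection or requests an implementation explicitly. A minimal, hedged usage sketch follows; the checkpoint name and dtype are illustrative placeholders, not part of this commit.

# Illustrative sketch only: the model name and dtype are placeholders.
import torch
from transformers import AutoModelForCausalLM

# Explicit request: with this fix, SDPA availability is hard-checked, so an
# informative error is raised if the installed torch does not provide it.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

# No request: the automatic path may pick SDPA when supported, otherwise it
# silently falls back to "eager".
model = AutoModelForCausalLM.from_pretrained("gpt2")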

File tree

2 files changed (+10, -8 lines)


src/transformers/modeling_utils.py (+8, -7)

@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1260,9 +1261,7 @@ def _autoset_attn_implementation(
                 raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal

         if use_flash_attention_2:
             logger.warning_once(
@@ -1275,13 +1274,15 @@ def _autoset_attn_implementation(
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"

         return config
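
The net effect of this hunk: a soft check when Transformers picks SDPA automatically, a hard check when the user asked for it. Below is a simplified, standalone sketch of that decision, not the actual transformers code; names follow the diff, and `supports_sdpa` stands in for the checks performed by `_check_and_enable_sdpa`.

# Simplified sketch of the dispatch behaviour after this commit; it only
# mirrors the branches visible in the diff above and is not the real code.
def pick_attn_implementation(requested_attn_implementation, use_flash_attention_2, supports_sdpa):
    if use_flash_attention_2:
        # Flash Attention 2 takes priority; its own checks happen elsewhere.
        return "flash_attention_2"
    if requested_attn_implementation in [None, "sdpa"]:
        if supports_sdpa:
            return "sdpa"
        if requested_attn_implementation == "sdpa":
            # Explicit request: fail loudly instead of silently degrading.
            raise ValueError("SDPA was requested but is not available in this environment.")
    # Automatic path without SDPA support, or any other unhandled request.
    return "eager"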

tests/test_modeling_common.py (+2, -1)

@@ -83,6 +83,7 @@
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )
 from transformers.utils.generic import ModelOutput

@@ -778,7 +779,7 @@ def _create_and_check_torchscript(self, config, inputs_dict):
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue

                 configs_no_init._attn_implementation = attn_implementation
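
The test change skips the SDPA branch when `is_torch_sdpa_available()` returns False, so CI on older torch builds no longer fails. Based only on the commit title, that helper presumably gates on the installed torch version; here is a hedged sketch of such a gate, where the 2.1.1 threshold is inferred from the title rather than shown in this diff.

# Hedged sketch of a torch-version gate in the spirit of
# is_torch_sdpa_available(); the 2.1.1 threshold is inferred from the commit
# title ("compatible with torch<2.1.1"), not from this diff.
from packaging import version


def sdpa_available() -> bool:
    try:
        import torch
    except ImportError:
        return False
    return version.parse(torch.__version__) >= version.parse("2.1.1")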
