
[FA-2] Fix fa-2 issue when passing config to from_pretrained #28043

Merged · 12 commits · Dec 15, 2023
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
@@ -2955,6 +2955,18 @@ def from_pretrained(
                **kwargs,
            )
        else:
            # In case one passes a config to `from_pretrained` together with "attn_implementation",
            # override the config's `_attn_implementation` attribute with the value from the kwargs.
            # Please see: https://github.com/huggingface/transformers/issues/28038

            # Overwrite `config._attn_implementation` with the one from the kwargs. The auto-factory
            # already pops `attn_implementation` from the kwargs, but this handles the case where
            # a user manually passes the config to `from_pretrained`.
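            # Deep-copy so the user-provided config object is not modified in place.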
            config = copy.deepcopy(config)

            kwarg_attn_imp = kwargs.pop("attn_implementation", None)
            if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
                config._attn_implementation = kwarg_attn_imp
            model_kwargs = kwargs

quantizer = None
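For context, a minimal usage sketch of the behavior this change enables: passing an explicit config together with `attn_implementation` to `from_pretrained` should now apply the kwarg instead of silently keeping the config's default. The model name is illustrative, and the example assumes `flash-attn` is installed on a supported GPU.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any FA2-capable checkpoint works

# Load the config explicitly, then pass it back to `from_pretrained` together with
# `attn_implementation`. With this fix the kwarg overrides the config's default.
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,  # FA2 requires fp16/bf16 weights
)

# The kwarg has overridden the config's default attention implementation.
assert model.config._attn_implementation == "flash_attention_2"
```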
25 changes: 25 additions & 0 deletions tests/test_modeling_utils.py
@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_no_flash_available_with_config(self):
Review thread on this test:

fxmarty (Contributor), Dec 14, 2023: Can you add a test, e.g. for Llama, passing a config + attn_implementation="flash_attention_2", checking that the correct attention class is loaded?

Author: You mean without AutoModel?

fxmarty: I mean a test for an architecture that does support FA2, passing both a config and attn_implementation="flash_attention_2".

Author: Done!

        with self.assertRaises(ValueError) as cm:
            config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

            _ = AutoModel.from_pretrained(
                "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
            )

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_wrong_attn_implementation(self):
        with self.assertRaises(ValueError) as cm:
            _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
@@ -1840,6 +1850,21 @@ def test_not_available_flash(self):

        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_not_available_flash_with_config(self):
        if is_flash_attn_2_available():
            self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")

        with self.assertRaises(ImportError) as cm:
            _ = AutoModel.from_pretrained(
                "hf-internal-testing/tiny-random-GPTBigCodeModel",
                config=config,
                attn_implementation="flash_attention_2",
            )

        self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

    def test_not_available_sdpa(self):
        if is_torch_sdpa_available():
            self.skipTest("This test requires torch<=2.0")