[FA-2] Fix fa-2 issue when passing config to from_pretrained #28043
Conversation
I am not 100% sure this approach is correct. cc @fxmarty does this look good to you (as you took care of the attention refactor)?
I'm a bit concerned about this - this is effectively a patch inside `from_pretrained`. Is it still possible to pass in both `config` and `attn_implementation`? Didn't do a final review on the recent refactor, so might be missing something. It's also not clear to me from just this PR why passing in a config would change whether or not I can pass in `attn_implementation`.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Good catch @amyeroberts. Even without a fix, see transformers/src/transformers/modeling_utils.py, lines 1295 to 1299 in 050e0b4.
src/transformers/modeling_utils.py
Outdated
# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
if config is not None:
    config._attn_implementation = model_kwargs.pop("attn_implementation", None)
Suggested change: drop the `if config is not None:` check and keep only:

# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
config._attn_implementation = model_kwargs.pop("attn_implementation", None)
`config` is a `PretrainedConfig` here.
LGTM, some tests are failing though
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
src/transformers/modeling_utils.py
Outdated
if kwargs.get("attn_implementation", None) is not None and getattr(
    config, "_attn_implementation", None
) != kwargs.get("attn_implementation", None):
This handles the case where users pass a `config` object to `from_pretrained`. Note that `AutoModelxxx.from_pretrained` pops `attn_implementation` from the kwargs in case one does not pass a config, but doesn't if we pass the config. Therefore this handles that corner case as well (passing a config --> `attn_implementation` does not get popped + `attn_implementation` is passed through the `from_pretrained` kwargs). If that's the case we should overwrite the config's `attn_implementation` with the one from the kwargs, assuming the user knows what they are doing.
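For illustration, here is a minimal standalone sketch of the intended precedence; the function and variable names are illustrative, not the exact ones used in `modeling_utils.py`:

```python
# Standalone sketch of the override logic described above (illustrative names).
def resolve_attn_implementation(config, kwargs):
    # An `attn_implementation` passed explicitly to `from_pretrained` should win
    # over whatever the (possibly user-supplied) config already carries.
    requested = kwargs.pop("attn_implementation", None)
    if requested is not None and getattr(config, "_attn_implementation", None) != requested:
        config._attn_implementation = requested
    return config
```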
cc @amyeroberts @fxmarty requesting another round of review!
@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_no_flash_available_with_config(self):
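A rough sketch of what this new test might look like; the checkpoint name is a placeholder and the real test added in the PR may differ:

```python
import unittest

from transformers import AutoConfig, AutoModel


class FlashAttention2ConfigTest(unittest.TestCase):
    def test_error_no_flash_available_with_config(self):
        # Placeholder checkpoint for an architecture without FA2 support
        # (hypothetical name; the actual test may use a different tiny model).
        model_id = "hf-internal-testing/tiny-random-model-without-fa2"
        config = AutoConfig.from_pretrained(model_id)
        with self.assertRaises(ValueError) as cm:
            AutoModel.from_pretrained(
                model_id, config=config, attn_implementation="flash_attention_2"
            )
        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
```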
Can you add a test for e.g. llama + passing a config + `attn_implementation="flash_attention_2"`, checking that the correct class is loaded?
You mean without `AutoModel`?
I mean a test for an architecture that does support FA2, passing both a config + `attn_implementation="flash_attention_2"`.
Done!
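For reference, such a test could look roughly like the sketch below; the tiny llama checkpoint and the assertions are assumptions (the actual test may also need `@require_flash_attn`/GPU decorators and a half-precision dtype):

```python
import unittest

from transformers import AutoConfig, AutoModelForCausalLM


class FlashAttention2LoadingTest(unittest.TestCase):
    def test_fa2_from_pretrained_with_user_config(self):
        # Hypothetical tiny llama-style checkpoint used for illustration.
        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        config = AutoConfig.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id, config=config, attn_implementation="flash_attention_2"
        )
        # The kwarg should win over the user-supplied config and select the FA2 attention classes.
        self.assertEqual(model.config._attn_implementation, "flash_attention_2")
        self.assertTrue(
            any("FlashAttention2" in module.__class__.__name__ for module in model.modules())
        )
```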
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Thanks for fixing and @fxmarty for clarifying the case for `use_flash_attention`! Just some nits.
src/transformers/modeling_utils.py
Outdated
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)

if kwargs.get("attn_implementation", None) is not None and config._attn_implementation != kwargs.get(
Do we want to `get` here or `pop` from the kwargs?
Good catch, I think `pop` would work best here.
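A quick illustration of the difference (plain Python, no transformers involved): `get` leaves the key in the kwargs so it could be forwarded again downstream, while `pop` consumes it once the config has been updated.

```python
# Illustrative only: why `pop` is preferable once the value has been copied onto the config.
kwargs = {"attn_implementation": "flash_attention_2", "torch_dtype": "auto"}

value = kwargs.get("attn_implementation", None)
# "attn_implementation" is still in kwargs -> it could be forwarded a second time.

value = kwargs.pop("attn_implementation", None)
# kwargs now only contains "torch_dtype" -> the config is the single source of truth.
```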
…uggingface#28043) * fix fa-2 issue * fix test * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * clenaer fix * up * add more robust tests * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * fixup * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * pop * add test --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
Fixes: #28038
Some users pass the `config` attribute to `from_pretrained` in order to modify the model's hyperparameters and change the underlying architecture. Note that in previous versions, before the attention refactor, it was possible to combine this with requesting Flash Attention 2. Now users hit an error when trying to perform that operation, because the logic for handling the model's config for FA2 changed a bit.
I propose a simple fix to mitigate this issue, which is overwriting the `_attn_implementation` attribute of `config` only in case it has been explicitly passed by the user. I can confirm that with this fix the snippet works as expected, as in earlier versions of transformers.
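The kind of call this fix targets presumably looks roughly like the following; the checkpoint name, config tweak and dtype are illustrative assumptions, not taken from the PR:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical reproduction of the pattern being fixed: modify the config,
# then pass it to `from_pretrained` together with `attn_implementation`.
model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_id)
config.num_hidden_layers = 2  # tweak the architecture through the config

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
```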
cc @amyeroberts @fxmarty