
[FA-2] Fix fa-2 issue when passing config to from_pretrained #28043

Merged: 12 commits merged into huggingface:main from fix-fa-2-from-config on Dec 15, 2023

Conversation

younesbelkada (Contributor)

What does this PR do?

Fixes: #28038

Some users pass a config object to from_pretrained in order to modify the model's hyperparameters and change the underlying architecture.

Note that in previous versions, before the attention refactor, it was possible to run:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, AutoConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    config=config,
    torch_dtype=torch.bfloat16, 
    use_flash_attention_2=True,
    low_cpu_mem_usage=True,
)

Now users hit an error when trying to run the snippet above, because the logic that handles the model's config for FA2 changed slightly.

I propose a simple fix to mitigate this issue: overwrite the `_attn_implementation` attribute of the config only if it has been passed by the user. With this fix I can confirm that the snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, AutoConfig

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    config=config,
    torch_dtype=torch.bfloat16, 
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)

works as expected, as in earlier versions of transformers.
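As a quick sanity check (a minimal sketch, not part of the PR; it assumes the loading snippet above has just been run), the private config attribute this PR adjusts can be inspected directly:

# Minimal sanity check (assumption: executed right after the loading snippet above).
# `_attn_implementation` is the private config attribute discussed in this PR.
assert model.config._attn_implementation == "flash_attention_2"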

cc @amyeroberts @fxmarty

younesbelkada (Contributor, Author)

I am not 100% sure this approach is correct. cc @fxmarty, does this look good to you (since you took care of the attention refactor)?

amyeroberts (Collaborator)

I'm a bit concerned about this: it's effectively a patch inside from_pretrained to add backwards compatibility that should already have been handled. The main question this raises for me is whether there are other FA parameters/behaviours we need to check.

Is it still possible to pass in both use_flash_attention_2 and config to from_pretrained? If not, it's not clear to me from the diff how this is addressed: use_flash_attention_2 isn't handled from the model kwargs.

I didn't do a final review on the recent refactor, so I might be missing something. It's also not clear to me from just this PR why passing in a config would change whether or not I can pass in attn_implementation.


fxmarty (Contributor) commented on Dec 14, 2023

Good catch

@amyeroberts Even without a fix, use_flash_attention_2=True along with a provided config works IMO, thanks to:

if use_flash_attention_2:
    logger.warning_once(
        'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
    )
    config._attn_implementation = "flash_attention_2"

Comment on lines 2960 to 2965

# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
if config is not None:
    config._attn_implementation = model_kwargs.pop("attn_implementation", None)

Contributor:
Suggested change (drop the `if config is not None:` guard):

# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
config._attn_implementation = model_kwargs.pop("attn_implementation", None)

config is a PretrainedConfig here.

@fxmarty (Contributor) left a review comment:

LGTM, some tests are failing though

Comment on lines 2967 to 2969

if kwargs.get("attn_implementation", None) is not None and getattr(
    config, "_attn_implementation", None
) != kwargs.get("attn_implementation", None):
younesbelkada (Contributor, Author):

This handles the case where users pass a config object to from_pretrained. Note that AutoModelXxx.from_pretrained pops attn_implementation from the kwargs if one does not pass a config, but does not pop it if a config is passed.

Therefore this handles that corner case as well (passing a config, so attn_implementation does not get popped, plus attn_implementation through the from_pretrained kwargs). In that case we should overwrite the config's attn_implementation with the one from the kwargs, assuming the user knows what they are doing.

https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L516-L540
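A minimal sketch of the guard described above (simplified, following the fragment quoted earlier in this thread rather than necessarily the exact diff that landed):

# Sketch: only override the (possibly user-supplied) config when the user also
# explicitly requested an attention implementation through the kwargs, and that
# request differs from what the config already carries.
if kwargs.get("attn_implementation", None) is not None and getattr(
    config, "_attn_implementation", None
) != kwargs.get("attn_implementation", None):
    config._attn_implementation = kwargs.get("attn_implementation")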

younesbelkada (Contributor, Author)

cc @amyeroberts @fxmarty requesting another round of review!

@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

    def test_error_no_flash_available_with_config(self):
fxmarty (Contributor) commented on Dec 14, 2023:

Can you add a test, e.g. llama + passing a config + attn_implementation="flash_attention_2", that checks the correct class is loaded?

younesbelkada (Contributor, Author):

You mean without AutoModel?

fxmarty (Contributor):

I mean a test for an architecture that does support FA2, passing both a config + attn_implementation="flash_attention_2".

younesbelkada (Contributor, Author):

Done!
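For reference, a rough sketch of what such a test could look like, assuming it lives in the existing FA2 test suite with the usual transformers test imports (the test name, decorators, and checkpoint below are illustrative assumptions, not necessarily what was committed):

    @require_flash_attn
    @require_torch_gpu
    def test_flash_attn_2_from_config(self):
        # Illustrative sketch: an FA2-capable architecture (Llama here) loaded with
        # both an explicit config and attn_implementation="flash_attention_2".
        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # hypothetical choice
        config = AutoConfig.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # The user-supplied config should now carry the requested implementation.
        self.assertEqual(model.config._attn_implementation, "flash_attention_2")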

younesbelkada and others added 2 commits December 14, 2023 18:11
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
@amyeroberts (Collaborator) left a review comment:

Thanks for fixing, and thanks @fxmarty for clarifying the case for use_flash_attention_2!

Just some nits.

src/transformers/modeling_utils.py (outdated; resolved)
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)

if kwargs.get("attn_implementation", None) is not None and config._attn_implementation != kwargs.get(
amyeroberts (Collaborator):

Do we want to `get` here or `pop` from the kwargs?

younesbelkada (Contributor, Author):

Good catch, I think pop would work best here
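i.e. roughly (a sketch of the adjustment, not the exact final diff):

# Sketch: pop instead of get, so the kwarg is consumed here and does not
# leak further down `from_pretrained`.
requested = kwargs.pop("attn_implementation", None)
if requested is not None and config._attn_implementation != requested:
    config._attn_implementation = requested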

younesbelkada and others added 3 commits December 14, 2023 19:28
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
younesbelkada merged commit 1e20931 into huggingface:main on Dec 15, 2023 (21 checks passed)

younesbelkada deleted the fix-fa-2-from-config branch on December 15, 2023 at 10:08
iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request on Dec 16, 2023:

[FA-2] Fix fa-2 issue when passing config to from_pretrained (huggingface#28043)

* fix fa-2 issue

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* clenaer fix

* up

* add more robust tests

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* fixup

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pop

* add test

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
amyeroberts added a commit that referenced this pull request on Dec 18, 2023:

[FA-2] Fix fa-2 issue when passing config to from_pretrained (#28043)
staghado pushed a commit to staghado/transformers that referenced this pull request on Jan 15, 2024:

[FA-2] Fix fa-2 issue when passing config to from_pretrained (huggingface#28043)
Successfully merging this pull request may close these issues:

Cannot specify config and attn_implementation simultaneously (#28038)

4 participants