Attention implementation cannot work together with config in AutoModel #30298

Closed · hiyouga opened this issue Apr 17, 2024 · 2 comments · Fixed by #30299
Comments

hiyouga (Contributor) commented Apr 17, 2024

System Info

  • transformers version: 4.40.0.dev0
  • Platform: Linux-5.15.0-100-generic-x86_64-with-glibc2.35
  • Python version: 3.11.8
  • Huggingface_hub version: 0.21.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • PyTorch version (GPU?): 2.2.0+cu121 (True)

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Similar to #28038

We want to pass a model config to from_pretrained together with an attn_implementation parameter, but the resulting attention implementation is not faithful to the one requested via attn_implementation:

from transformers import AutoConfig, AutoModelForCausalLM
model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, attn_implementation="eager")
print(model.config._attn_implementation)
# sdpa
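
For comparison, setting the implementation on the config object itself (so that config._attn_implementation_internal is populated before from_pretrained inspects it) seems to sidestep the problem. This is only a workaround sketch, not verified beyond the version listed above:

from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
# Assigning through the (private) property setter stores the value in
# config._attn_implementation_internal, so it is treated as an explicit request.
config._attn_implementation = "eager"
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
print(model.config._attn_implementation)
# eager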

Expected behavior

_attn_implementation should be eager

hiyouga (Contributor, Author) commented Apr 17, 2024

Given the logic below, we cannot force the model to use eager attention: config._attn_implementation falls back to "eager" when config._attn_implementation_internal is None [1]. Hence the condition config._attn_implementation != kwarg_attn_imp never holds, config._attn_implementation_internal is left untouched (still None), and the model ends up with SDPA attention [2].

# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038
# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
    config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
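
To make the failure mode concrete, here is a minimal sketch of what that guard sees for a freshly loaded config (illustrative only):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
kwarg_attn_imp = "eager"

print(config._attn_implementation_internal)            # None: nothing was requested yet
print(config._attn_implementation)                     # "eager": the property falls back
print(config._attn_implementation != kwarg_attn_imp)   # False: the guard above never fires

So _attn_implementation_internal stays None, and the SDPA branch in [2] is taken.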

I think we should use config._attn_implementation_internal != kwarg_attn_imp instead
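
For illustration, here is how that guard could look; a sketch of the idea only, not necessarily the exact change made in #30299. It mirrors the excerpt above but compares against the raw internal attribute, which stays None until something explicitly sets it:

config = copy.deepcopy(config)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
# Compare against the internal attribute rather than the property, so the
# "eager" fallback does not mask an unset value.
if kwarg_attn_imp is not None and config._attn_implementation_internal != kwarg_attn_imp:
    config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs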

1:

@property
def _attn_implementation(self):
    # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
    if hasattr(self, "_attn_implementation_internal"):
        if self._attn_implementation_internal is None:
            # `config.attn_implementation` should never be None, for backward compatibility.
            return "eager"
        else:
            return self._attn_implementation_internal
    else:
        return "eager"

@_attn_implementation.setter
def _attn_implementation(self, value):
    self._attn_implementation_internal = value

2:

elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
    # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
    config = cls._check_and_enable_sdpa(
        config,
        hard_check_only=False if requested_attn_implementation is None else True,
    )

amyeroberts (Collaborator) commented

cc @fxmarty
