🚨🚨[core] Completely rewrite the masking logic for all attentions #37866
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed 53ca556 to ce42aa7
ArthurZucker left a comment:
Looks very, very nice!
One thing I want to consider is to rather call the sliding, causal and chunked variants directly in the modeling.
For example:
- llama only needs `causal_mask`; under the hood, the causal mask should do an AND with sdpa, flash or flex
- gemma needs `sliding_causal`: same
- llama4 needs chunked causal
I want the modeling to call an explicit function, rather than the mega-general one!
This would keep our philosophy, as we don't want too-general stuff happening when not needed (e.g. llama should never care about sliding in its code paths).
Also missing: docs about how to add a new func!
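A rough sketch of what these explicit per-pattern entry points could look like at the modeling level; the helper names follow the ones discussed above, but the import path and exact signatures are assumptions on my part, not the final API:

```python
# Sketch only: each model calls the one explicit mask helper it needs; the helper is expected
# to handle the backend-specific format (sdpa / flash / flex) internally. Signatures assumed.
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask


def llama_style_mask(config, input_embeds, attention_mask, cache_position, past_key_values):
    # A llama-style model only ever needs the plain causal mask.
    return create_causal_mask(
        config=config,
        input_embeds=input_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
    )


def gemma_style_mask(config, input_embeds, attention_mask, cache_position, past_key_values):
    # A gemma-style model needs the sliding-window variant instead; same call shape.
    return create_sliding_window_causal_mask(
        config=config,
        input_embeds=input_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
    )
```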
Wow!!!!!!!! 🚀 This PR seems worth a manual full CI run. Ping me when you think this PR is ready to trigger CI.
ArthurZucker left a comment:
Damn nice
ArthurZucker left a comment:
Review for the core logic: IMO it can be simplified! But the modeling part is absolutely perfect!
For the visualization, I'll see how we could just overwrite the repr without affecting other operations!
vasqu left a comment:
Just a quick question on this refactor: If I understand the code correctly, then the focus is currently on causal masks only, correct?
It would be nice to add a non-causal alternative which only uses a padding mask and expands to the q_len and kv_len respectively. That's more food for thought :D I don't want to make this PR even harder than it is.
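For illustration, a padding-only (non-causal) expansion of that kind is easy to sketch in plain PyTorch; this is just the idea, not anything from the PR:

```python
import torch


def padding_only_mask(padding_mask: torch.Tensor, q_len: int) -> torch.Tensor:
    """Expand a 2D padding mask (batch, kv_len) into a 4D boolean mask (batch, 1, q_len, kv_len):
    every query position may attend to every non-padded key position, with no causal constraint."""
    batch_size, kv_len = padding_mask.shape
    mask = padding_mask[:, None, None, :].to(torch.bool)  # (batch, 1, 1, kv_len)
    return mask.expand(batch_size, 1, q_len, kv_len)


# Example: batch of 2, kv_len of 4, second sequence has one padded position.
padding = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
print(padding_only_mask(padding, q_len=3).shape)  # torch.Size([2, 1, 3, 4])
```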
Force-pushed 0b6bbe5 to 7fc4f91
For now it's mostly on causal masks because they are the ones we need, but the idea is that it can be extended super easily from a set of mask primitives!
ArthurZucker left a comment:
Mega nice!
TODO before merging:
- move the `causal_mask_mapping` to a class attribute!
- show an example of how to register a new function, but minimal (without sdpa correction for example)
- make sure full graph training is not broken maybe? or at least fa2 training
That should be it!
Force-pushed 28e232c to 5170e9d
Force-pushed dee568c to 4a2e906
ArthurZucker left a comment:
Let's go!
```python
def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)
```
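For context, a function like this would presumably then be registered so models can select it by name, along these lines (the registry class name is an assumption on my part, mirroring the existing attention interface, and `sdpa_mask` would come from the masking utilities):

```python
# Assumption: registration goes through an attention-mask registry analogous to AttentionInterface.
from transformers import AttentionMaskInterface

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```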
Let's rather show how to do something like the PaliGemma or document masking here, something relevant!
Those are a bit different: that's modifying the mask pattern, vs adding a new mask format for the attention itself (both are complementary).
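As a rough sketch of the "modifying the mask pattern" side (document masking expressed as a mask callable in the (batch_idx, head_idx, q_idx, kv_idx) style; purely illustrative, not this PR's API):

```python
import torch


def make_document_mask_function(document_ids: torch.Tensor):
    # Mask callable in the (batch_idx, head_idx, q_idx, kv_idx) style: a token may only
    # attend to tokens belonging to the same packed document.
    def document_mask(batch_idx, head_idx, q_idx, kv_idx):
        return document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
    return document_mask


# Example: one sequence packing two documents of 3 and 2 tokens.
doc_ids = torch.tensor([[0, 0, 0, 1, 1]])
mask_fn = make_document_mask_function(doc_ids)
print(mask_fn(0, 0, torch.tensor(1), torch.tensor(4)))  # tensor(False): different documents
print(mask_fn(0, 0, torch.tensor(3), torch.tensor(4)))  # tensor(True): same document
```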
Hi @Cyrilvallez, I noticed that after this PR, calling `prepare_inputs_for_generation` no longer returns a plain tensor for the attention mask. Reproducer:

```python
from transformers import AutoModelForCausalLM
from transformers.cache_utils import HybridCache
import torch

model_id = 'hf-internal-testing/tiny-random-Gemma3ForCausalLM'
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = torch.arange(6).view(2, 3)
attention_mask = torch.ones_like(inputs)
# cache is required, w/o cache a tensor is returned as expected
cache = HybridCache(model.config, max_batch_size=2, max_cache_len=3)
model_kwargs = model.prepare_inputs_for_generation(
    inputs, attention_mask=attention_mask, past_key_values=cache, cache_position=torch.arange(3)
)
mask = model_kwargs['attention_mask']
assert isinstance(mask, torch.Tensor), f"expected attention mask to be tensor, got {mask}"
```

Before the PR, this assertion passed.
```diff
 if not hasattr(model.config, "layer_types"):
     # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
     # export will use `StaticCache` by default.
     logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
     self.model = TorchExportableModuleWithStaticCache(model)
 else:
-    if model.config.cache_implementation == "hybrid":
-        self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
-    else:
-        raise ValueError(
-            f"Unsupported cache implementation: {model.config.cache_implementation}. "
-            "Please use `hybrid` or `static`."
-        )
+    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
```
@Cyrilvallez What is `layer_types`? I'm concerned about whether the changes here are backwards compatible. For existing models on the Hub like google/gemma-3-1b, the config doesn't seem to come with `layer_types`, so it will fall back to the static cache, which doesn't look correct.
One comment I have is that the way mask calculation is incorporated in most models is that it happens at the model level, e.g. here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L565-L566. However, different cache implementations may imply different attention masks. Different layers may have different cache implementations; for example, some layers can have sliding windows of different sizes, while others may use an attention sink to keep, say, the first few tokens. I feel the best place for the custom mask is the attention layer, so that said layer can pass all the information, including the KV cache, to the custom mask function (e.g. the layer index).
Hey, sorry all, I was on vacation!

@BenjaminBossan indeed, this is expected. Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by `prepare_inputs_for_generation` (one dict entry per attention type).

@guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@kimishpatel In transformers, Caches are not at the layer level, so as of now only some configurations are acceptable (though I've had in mind to change that for some time, to make it more modular). And computing the mask at the attention-layer level is not only redundant (most layers will create the same mask, wasting precious time), but it breaks compile completely, as we cannot pre-compute the masks anymore. For now, there are no known models with sliding windows of different sizes for different layers, so we decided to make it as simple as possible. This was taken into account when doing this refactor though, no worries; we definitely thought about making it scale easily in the future should this scenario happen.
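A toy illustration of the pre-compute-once argument: one mask per attention type is built up front, and the only per-layer work is a dict lookup (shapes and window size here are arbitrary, not library code):

```python
import torch

# One pre-computed mask per attention type; each layer just looks up the mask for its type.
layer_types = ["sliding_attention", "full_attention", "sliding_attention"]

full_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool).tril()              # standard causal
sliding_mask = full_mask & ~torch.ones(4, 4, dtype=torch.bool).tril(-2)  # causal, window of 2

causal_mask_mapping = {"full_attention": full_mask, "sliding_attention": sliding_mask}

for i, layer_type in enumerate(layer_types):
    mask = causal_mask_mapping[layer_type]  # no per-layer mask computation
    print(f"layer {i} ({layer_type}): {int(mask.sum())} visible positions")
```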
Resolves CI errors such as this one: https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're unrelated and investigated independently.

After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by `prepare_inputs_for_generation` (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of the attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I just raise an error, as I don't know how to deal with that. For our tests, this is enough, but we might need to find a better solution in the future.
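A minimal sketch of the kind of handling that commit describes (the helper name is hypothetical, not PEFT's actual code):

```python
def unwrap_attention_mask(attention_mask):
    # Hypothetical helper: after huggingface/transformers#37866, hybrid models may hand back
    # a dict of masks (one entry per attention type) instead of a single tensor.
    if isinstance(attention_mask, dict):
        if len(attention_mask) == 1:
            return next(iter(attention_mask.values()))
        raise ValueError(
            f"Cannot handle an attention mask dict with {len(attention_mask)} entries."
        )
    return attention_mask
```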
To be fair, I doubt attention mask calculation has that much impact on performance for most models. I have implemented a ring-buffer-based KV cache that needs a very different way of calculating the mask, and that mask calculation, while redundant, happens at the attention layer. I have not observed any significant amount of time spent there. Although for the block mask in flex attention, I think you might be right. That one is non-trivial.
How so? I do understand though that transformers is not exactly providing building blocks for model authoring, so from that perspective composability and modularity have limited value, I suppose.
@Cyrilvallez Let's follow up in #38646
## Purpose ##
* Fix tracing for model definitions introduced as part of `transformers==4.53`
* Resolves #1603

## Background ##
In the latest transformers release, this change landed which changed the name of the function which generates the causal mask: huggingface/transformers#37866

## Changes ##
* Extend the list of function names to ignore during tracing, specifically targeting functions which create causal masks
* Update debugger tool to use ignore list from `DatasetArguments`
* Update Tracer to skip masking functions as part of autowrapping any functions which were not caught by the autowrapper

## Testing ##
* `tests/llmcompressor/transformers/tracing/test_models.py` now passes with the latest `transformers==4.53`

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
What does this PR do?
As per the title. The goal is to properly separate masking logic from modeling code itself, to continue our objective of simplifying the library.