Conversation

@Cyrilvallez (Member) commented Apr 29, 2025

What does this PR do?

As per the title. The goal is to properly separate masking logic from modeling code itself, to continue our objective of simplifying the library.

  • Code is much simpler to understand
  • Much more general: it always works, for all lengths and all attention implementations, e.g.:
    • flex attention now works with sliding/hybrid models (not the case before)
    • FA2 now works with static caches, including models with default hybrid structures (previously this was only the case for hybrid models)
  • All models can use all Cache classes (e.g. models with a hybrid structure can fall back to DynamicCache)
  • Extremely scalable in the future: any pattern of layers can be taken into account WITHOUT ANY CHANGE to modeling or masking code. A new masking pattern (e.g. the recently introduced chunked attention for Llama4) can be added with minimal effort (just add a new mask_mod to describe it, and voilà! — see the sketch after this list)
  • A single source of truth: mask creation was copied over and over again, sometimes with slight changes to account for sliding windows or similar. This would eventually lead to mistakes or inefficiencies as things were "forced to fit", and a lot of maintenance burden
  • Compile compatible: the new mask creation is technically compile compatible; it should however stay outside what is compiled in the forward to avoid recompilations, as is done in generate
  • Allows external mask creation: if someone passes their own custom attention implementation, they may need their own mask creation function, which is now supported
  • The TGI/vLLM backend should be even more efficient now, as we no longer waste compute creating a useless mask (previously a 4D mask was created as for sdpa, even though it was never used)
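
To make the mask_mod idea above concrete, here is a small self-contained sketch (illustrative only, not the exact API introduced by this PR): a mask_mod is a rule over (query, key) indices, and a generic builder materializes it as a boolean mask and combines it with the padding mask via a logical AND.

import torch

def sliding_window_mask_mod(window):
    # True where the key position is visible from the query position
    def mask_mod(q_idx, kv_idx):
        return (kv_idx <= q_idx) & (q_idx - kv_idx < window)
    return mask_mod

def build_bool_mask(mask_mod, q_len, kv_len, padding_mask=None):
    # Broadcast query/key indices to a (q_len, kv_len) grid and apply the rule
    q_idx = torch.arange(q_len).view(-1, 1)
    kv_idx = torch.arange(kv_len).view(1, -1)
    mask = mask_mod(q_idx, kv_idx)
    if padding_mask is not None:  # (batch, kv_len) bools, True for real tokens
        mask = mask.unsqueeze(0) & padding_mask[:, None, :]
    return mask

print(build_bool_mask(sliding_window_mask_mod(window=2), q_len=4, kv_len=4))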

@github-actions github-actions bot marked this pull request as draft April 29, 2025 14:26
@github-actions
Copy link
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title Refactor mask [core] Completely rewrite the masking logic for all attentions May 8, 2025
@Cyrilvallez Cyrilvallez force-pushed the refactor-mask branch 3 times, most recently from 53ca556 to ce42aa7 Compare May 12, 2025 07:49
@Cyrilvallez Cyrilvallez marked this pull request as ready for review May 12, 2025 16:25
@ArthurZucker (Collaborator) left a comment

Looks very very nice!
One thing I want to consider is to rather call the sliding, causal, and chunked variants directly in the modeling.
For example:

  • llama only needs causal_mask; under the hood, the causal mask should do an and with the sdpa, flash, or flex variant
  • gemma needs sliding_causal: same
  • llama4 needs chunked causal

I want the modeling to call an explicit function rather than the mega-general one! (See the sketch after this comment.)

This would keep our philosophy, as we don't want overly general stuff happening when not needed (e.g. llama should never care about sliding in its code paths).

Also missing: docs about how to add a new func!
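
A minimal sketch of what those explicit per-pattern entry points could look like (illustrative names and signatures, an assumption rather than the merged implementation): each model calls the function matching its layer type, and the chosen backend (sdpa / flash / flex) consumes the resulting boolean rule.

import torch

def _index_grid(q_len, kv_len):
    return torch.arange(q_len).view(-1, 1), torch.arange(kv_len).view(1, -1)

def create_causal_mask(q_len, kv_len):
    # llama-style layers: plain causal
    q, k = _index_grid(q_len, kv_len)
    return k <= q

def create_sliding_window_causal_mask(q_len, kv_len, window):
    # gemma-style layers: causal AND within the sliding window
    q, k = _index_grid(q_len, kv_len)
    return (k <= q) & (q - k < window)

def create_chunked_causal_mask(q_len, kv_len, chunk_size):
    # llama4-style layers: causal AND restricted to the same chunk
    q, k = _index_grid(q_len, kv_len)
    return (k <= q) & (q // chunk_size == k // chunk_size)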

@ydshieh (Collaborator) commented May 15, 2025

Wow!!!!!!!! 🚀

This PR seems worth a manual full CI run. Ping me when you think it is ready to trigger CI.

@ArthurZucker (Collaborator) left a comment

Damn nice

@ArthurZucker (Collaborator) left a comment

Review of the core logic: IMO it can be simplified! But the modeling part is absolutely perfect!
For the visualization, I'll see how we could just overwrite the repr without affecting other operations!

@vasqu (Contributor) left a comment

Just a quick question on this refactor: If I understand the code correctly, then the focus is currently on causal masks only, correct?

It would be nice to add a non-causal alternative, which would only use a padding mask expanded to q_len and kv_len. That's more food for thought :D I don't want to make this PR even harder than it is.

@Cyrilvallez (Member, Author) commented

For now it's mostly causal masks because they are the ones we need, but the idea is that this can be extended super easily from a set of mask primitives! (A quick sketch of the non-causal case follows.)
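
For illustration, a minimal sketch of the non-causal case described above (the shape conventions are assumptions, not the merged API): only padding matters, expanded to the (q_len, kv_len) grid.

import torch

def padding_only_mask(padding_mask, q_len):
    # padding_mask: (batch, kv_len) bools, True where tokens are real
    # returns: (batch, q_len, kv_len) bools, every query sees every real token
    return padding_mask[:, None, :].expand(-1, q_len, -1)

print(padding_only_mask(torch.tensor([[True, True, False]]), q_len=2))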

@ArthurZucker (Collaborator) left a comment

Mega nice!
TODO before merging:

  1. Move the causal_mask_mapping to a class attribute!
  2. Show an example of how to register a new function, but keep it minimal (without the sdpa correction, for example); a sketch follows below
  3. Make sure full-graph training is not broken, or at least FA2 training

That should be it!
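
For illustration, a hedged sketch of what registering a new mask-creation function could look like (the registry and decorator names here are assumptions, not the merged API; the real mechanism lives in the library's masking utilities):

import torch

# Hypothetical registry keyed by attention-implementation name
MASK_CREATION_FUNCTIONS = {}

def register_mask_function(name):
    def wrapper(fn):
        MASK_CREATION_FUNCTIONS[name] = fn
        return fn
    return wrapper

@register_mask_function("my_attention")
def my_attention_mask(q_len, kv_len, **kwargs):
    # Minimal example: a plain causal rule; a kernel that masks internally could return None
    q = torch.arange(q_len).view(-1, 1)
    k = torch.arange(kv_len).view(1, -1)
    return k <= q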

@Cyrilvallez Cyrilvallez changed the title [core] Completely rewrite the masking logic for all attentions 🚨🚨[core] Completely rewrite the masking logic for all attentions May 20, 2025
@Cyrilvallez Cyrilvallez force-pushed the refactor-mask branch 2 times, most recently from dee568c to 4a2e906 Compare May 21, 2025 11:30
@ArthurZucker (Collaborator) left a comment

Let's go!

Comment on lines +141 to +144
def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

Collaborator:

Let's rather show how to do something like the PaliGemma or document masking here, something relevant!

Member (Author):

Those are a bit different: that would be modifying the mask pattern, whereas this example adds a new mask format for the attention itself (the two are complementary). A sketch of the document-masking pattern is below for reference.
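
For reference, a minimal sketch of the document-masking pattern mentioned above (illustrative only, not code from this PR): in a packed batch, a token attends causally only to tokens from the same document.

import torch

def document_causal_mask(document_ids):
    # document_ids: (seq_len,) int tensor, one id per packed token
    q_doc = document_ids.view(-1, 1)
    k_doc = document_ids.view(1, -1)
    positions = torch.arange(len(document_ids))
    causal = positions.view(-1, 1) >= positions.view(1, -1)
    return causal & (q_doc == k_doc)

print(document_causal_mask(torch.tensor([0, 0, 1, 1, 1])))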

@Cyrilvallez Cyrilvallez merged commit 163138a into main May 22, 2025
21 checks passed
@Cyrilvallez Cyrilvallez deleted the refactor-mask branch May 22, 2025 09:38
@BenjaminBossan (Member) commented

Hi @Cyrilvallez, I noticed that after this PR, calling prepare_inputs_for_generation can return an attention_mask that is a dict instead of a tensor. Is this expected? If yes, I need to update PEFT.

Reproducer:

from transformers import AutoModelForCausalLM
from transformers.cache_utils import HybridCache
import torch

model_id = 'hf-internal-testing/tiny-random-Gemma3ForCausalLM'
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = torch.arange(6).view(2, 3)
attention_mask = torch.ones_like(inputs)
# cache is required, w/o cache a tensor is returned as expected
cache = HybridCache(model.config, max_batch_size=2, max_cache_len=3)

model_kwargs = model.prepare_inputs_for_generation(
    inputs, attention_mask=attention_mask, past_key_values=cache, cache_position=torch.arange(3)
)
mask = model_kwargs['attention_mask']
assert isinstance(mask, torch.Tensor), f"expected attention mask to be tensor, got {mask}"

Before the PR (f8630c778c9220defecf1e3026d3438108b0baba), this passes. After the PR (163138a911c1fb4451ec4b32edaee20918a59def), it fails with:

AssertionError: expected attention mask to be tensor, got {'sliding_attention': tensor([[[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]],


        [[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]]])}

Comment on lines +60 to +66
+ if not hasattr(model.config, "layer_types"):
+     # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
+     # export will use `StaticCache` by default.
+     logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
+     self.model = TorchExportableModuleWithStaticCache(model)
+ else:
-     if model.config.cache_implementation == "hybrid":
-         self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
-     else:
-         raise ValueError(
-             f"Unsupported cache implementation: {model.config.cache_implementation}. "
-             "Please use `hybrid` or `static`."
-         )
+     self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
Contributor:

@Cyrilvallez What is layer_types? I'm concerned about whether the changes here are backwards compatible. For existing models on the Hub like google/gemma-3-1b, the config doesn't seem to come with layer_types, so it will fall back to the static cache, which doesn't look correct.
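
To make the concern concrete, a small illustrative helper (an assumption about how such a decision could be expressed, not the actual integration code): a config without layer_types, or with a single layer type, would get the static cache, while a genuinely hybrid config would get the hybrid cache.

def pick_export_cache(config) -> str:
    # layer_types is e.g. ["sliding_attention", ..., "full_attention"] on refactored configs
    layer_types = getattr(config, "layer_types", None)
    if layer_types and len(set(layer_types)) > 1:
        return "hybrid"
    return "static"

Under this logic, an older gemma-3 config without layer_types would indeed land in the "static" branch, which is the behaviour being questioned here.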

@guangy10 guangy10 mentioned this pull request Jun 6, 2025
@kimishpatel (Contributor) commented

One comment I have is that, in most models, the mask calculation happens at the model level, e.g. here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L565-L566. However, different cache implementations may imply different attention masks. Different layers may have different cache implementations; for example, some layers can have sliding windows of different sizes, while others may use an attention sink to keep, say, the first few tokens. I feel the best place for a custom mask is the attention layer, so that the layer can pass all the relevant information, including the KV cache, to the custom mask function (e.g. layer_index).

@Cyrilvallez (Member, Author) commented

Hey, sorry all, I was on vacation!

@BenjaminBossan indeed, this is expected. Models using different types of attention in different layers (e.g. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type). Sorry this broke your tests! 😬

@guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@kimishpatel In transformers, Caches are not at the layer level, so as of now only some configurations are acceptable (though I've had it in mind for some time to change that and make it more modular). And computing the mask at the attention-layer level is not only redundant (most layers would create the same mask, wasting precious time), it also breaks compile completely, since we could no longer pre-compute the masks. For now there are no known models with sliding windows of different sizes for different layers, so we decided to make it as simple as possible. This was taken into account when doing this refactor though, no worries: we definitely thought about scaling easily in the future should this scenario happen. (A small sketch of the "compute once per attention type at the model level" idea follows.)
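
A minimal sketch of that idea (assumptions only, reusing the create_* helpers sketched after ArthurZucker's first review above): each distinct mask is built once per forward pass at the model level, and every layer then looks up the mask for its own type.

def build_masks_by_layer_type(layer_types, q_len, kv_len, window):
    # Build each distinct mask exactly once, e.g. {"full_attention": ..., "sliding_attention": ...}
    masks = {}
    for layer_type in set(layer_types):
        if layer_type == "sliding_attention":
            masks[layer_type] = create_sliding_window_causal_mask(q_len, kv_len, window)
        else:
            masks[layer_type] = create_causal_mask(q_len, kv_len)
    return masks

# Each decoder layer then receives masks[its_layer_type] instead of computing its own.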

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jun 10, 2025
Resolves CI errors such as this one:

https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're
unrelated and investigated independently.

After the transformers change in
huggingface/transformers#37866, it can happen
that:

> Models using different types of attention in different layers (e.g.
> gemma3) will now have a dict returned by
> prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we
need to adjust the code for the possibility of attention_mask being a
dict. Right now, I simply extract the single value if the dict has just
one element. For other sizes, I just raise an error, as I don't know how
to deal with that. For our tests, this is enough but we might need to
find a better solution in the future.
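
A hedged sketch of the kind of handling described in that commit message (illustrative, not PEFT's actual code): accept either a plain tensor or the new per-attention-type dict.

def extract_attention_mask(attention_mask):
    # New behaviour: hybrid models may return {"sliding_attention": mask, "full_attention": mask, ...}
    if isinstance(attention_mask, dict):
        if len(attention_mask) == 1:
            return next(iter(attention_mask.values()))
        raise ValueError(f"Cannot handle an attention_mask dict with keys {list(attention_mask)}")
    return attention_mask
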
@kimishpatel (Contributor) commented Jun 10, 2025

> And computing the mask at the attention-layer level is not only redundant (most layers would create the same mask, wasting precious time)

To be fair, I doubt attention mask calculation has that much impact on performance for most models. I have implemented a ring-buffer-based KV cache that needs a very different way of calculating the mask, and that mask calculation, while redundant, happens at the attention layer. I have not observed any significant amount of time spent there.

Although for the block mask in flex attention, I think you might be right. That one is non-trivial.

> Caches are not at the layer level

How so? The cache's update functions accept layer_idx, so they do have to know which layer the update belongs to.

I do understand, though, that transformers is not exactly providing building blocks for model authoring, so from that perspective composability and modularity have limited value, I suppose.

@guangy10 (Contributor) commented Jun 10, 2025

> @guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@Cyrilvallez Let's follow up in #38646

BenjaminBossan added a commit to huggingface/peft that referenced this pull request Jun 27, 2025
kylesayrs added a commit to vllm-project/llm-compressor that referenced this pull request Jul 1, 2025
## Purpose ##
* Fix tracing for model definitions introduced as part of
`transformers==4.53`
* Resolves #1603

## Background ##
In the latest transformers release, this change landed, which changed the
name of the function that generates the causal mask.
huggingface/transformers#37866

## Changes ##
* Extend the list of function names to ignore during tracing,
specifically targeting functions which create causal masks
* Update debugger tool to use ignore list from `DatasetArguments`
* Update Tracer to skip masking function as part of autowrapping any
functions which were not caught by the autowrapper

## Testing ##
* `tests/llmcompressor/transformers/tracing/test_models.py` now passes
with the latest `transformers==4.53`

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
aireilly pushed a commit to aireilly/llm-compressor that referenced this pull request Jul 30, 2025
csqaiub added a commit to csqaiub/peft that referenced this pull request Sep 28, 2025