-
Notifications
You must be signed in to change notification settings - Fork 31.3k
🚨🚨[core] Completely rewrite the masking logic for all attentions #37866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+2,984
−6,808
Merged
Changes from all commits
Commits
Show all changes
193 commits
Select commit
Hold shift + click to select a range
e083d5c
start
Cyrilvallez e1d43c4
start having a clean 4d mask primitive
Cyrilvallez 59a69c4
Update mask_utils.py
Cyrilvallez 8aa61b0
Update mask_utils.py
Cyrilvallez ee7bafd
switch name
Cyrilvallez bfcc5d8
Update masking_utils.py
Cyrilvallez f92757a
add a new AttentionMask tensor class
ArthurZucker 932c17b
fix import
ArthurZucker 1227356
nits
ArthurZucker 542d054
fixes
ArthurZucker 99235bb
use full and quandrants
ArthurZucker 6d16c6b
general sdpa mask for all caches
Cyrilvallez c98bc68
style
Cyrilvallez f3027fe
start some tests
Cyrilvallez 7eeed31
tests with sliding, chunked
Cyrilvallez ddd6059
add styling
ArthurZucker 9397d17
test hybrid
Cyrilvallez bb6ea15
Update masking_utils.py
Cyrilvallez d7c4fa7
small temp fixes
Cyrilvallez 3ea388c
Update modeling_gemma2.py
Cyrilvallez 2165232
compile compatible
Cyrilvallez 07fda4f
Update masking_utils.py
Cyrilvallez eed6383
improve
Cyrilvallez 15485c3
start making it more general
Cyrilvallez ce4080b
Update masking_utils.py
Cyrilvallez 039e444
generate
Cyrilvallez 8b99dde
make it work with flex style primitives!
Cyrilvallez 14f0163
Update masking_utils.py
Cyrilvallez 6bb8742
Update masking_utils.py
Cyrilvallez 1c16acb
Update masking_utils.py
Cyrilvallez 8f75abb
improve
Cyrilvallez 387cdb4
Update cache_utils.py
Cyrilvallez 346bfa9
Update masking_utils.py
Cyrilvallez b826d91
simplify - starting to look good!
Cyrilvallez 83f20d3
Update masking_utils.py
Cyrilvallez 6fa1d35
name
Cyrilvallez 05777fb
Update masking_utils.py
Cyrilvallez 6ab437e
style
Cyrilvallez 1d6c900
Update masking_utils.py
Cyrilvallez b5e1ebd
Update masking_utils.py
Cyrilvallez 26a428c
Update masking_utils.py
Cyrilvallez 71162e4
Update masking_utils.py
Cyrilvallez 28d5a19
small fix for flex
Cyrilvallez 8532033
flex compile
Cyrilvallez f3c8e7c
FA2
Cyrilvallez d0c0b40
Update masking_utils.py
Cyrilvallez bbc6bec
Escape for TGI/vLLM!
Cyrilvallez ada1627
Update masking_utils.py
Cyrilvallez 1620f1e
Update masking_utils.py
Cyrilvallez 8f3a2a0
Update masking_utils.py
Cyrilvallez d5b3285
General case without cache
Cyrilvallez ce22728
rename
Cyrilvallez 7bf1352
full test on llama4
Cyrilvallez 05529ff
small fix for FA2 guard with chunk
Cyrilvallez 5afd898
Update modeling_gemma2.py
Cyrilvallez 67bea3f
post rebase cleanup
Cyrilvallez 7bb501d
FA2 supports static cache!
Cyrilvallez 9bb6219
Update modeling_flash_attention_utils.py
Cyrilvallez f4849ab
Update flex_attention.py
Cyrilvallez 44f9b65
Update masking_utils.py
Cyrilvallez dc52eb3
Update masking_utils.py
Cyrilvallez d2f645d
Update utils.py
Cyrilvallez 07bd06c
override for export
Cyrilvallez f6735f2
Update executorch.py
Cyrilvallez f2d8a54
Update executorch.py
Cyrilvallez ee0afdd
Update executorch.py
Cyrilvallez 73549e5
Update executorch.py
Cyrilvallez 7031bc4
Update masking_utils.py
Cyrilvallez 552b586
Update masking_utils.py
Cyrilvallez d2fb4de
output attentions
Cyrilvallez 628dcd8
style
Cyrilvallez 27fb93f
Update masking_utils.py
Cyrilvallez a091517
Update executorch.py
Cyrilvallez cbf1144
Add doicstring
Cyrilvallez 59eb3cc
Add license and put mask visualizer at the end
Cyrilvallez 85ab5da
Update test_modeling_common.py
Cyrilvallez daf5bee
fix broken test
Cyrilvallez 201da65
Update test_modeling_gemma.py
Cyrilvallez f06c7cd
Update test_modeling_gemma2.py
Cyrilvallez 73a12b4
Use fullgraph=False with FA2
Cyrilvallez d0f6f7f
Update utils.py
Cyrilvallez 5a25046
change name
Cyrilvallez cd9461d
Update masking_utils.py
Cyrilvallez 4169bc3
improve doc
Cyrilvallez 3166d47
change name
Cyrilvallez 528aab6
Update modeling_attn_mask_utils.py
Cyrilvallez 77f2c66
more explicit logic based on model's property
Cyrilvallez 58d1384
pattern in config
Cyrilvallez 7a6ac01
extend
Cyrilvallez f390675
fixes
Cyrilvallez e26eb84
make it better
Cyrilvallez bca422c
generalize to other test models
Cyrilvallez a66697e
fix
Cyrilvallez 7ea8db7
Update masking_utils.py
Cyrilvallez 1d3751f
fix
Cyrilvallez df43917
do not check mask equivalence if layer types are different
Cyrilvallez 095746a
executorch
Cyrilvallez 770422c
Update modeling_gemma2.py
Cyrilvallez 0b5a817
Update masking_utils.py
Cyrilvallez cf5212c
use layer_idx instead
Cyrilvallez e28d663
adjust
Cyrilvallez 53e9f47
Update masking_utils.py
Cyrilvallez 8e2bdd1
test
Cyrilvallez 558c47e
fix imports
Cyrilvallez df49780
Update modeling_gemma2.py
Cyrilvallez a87f7dd
other test models
Cyrilvallez 8426b34
Update modeling_llama4.py
Cyrilvallez 413d446
Update masking_utils.py
Cyrilvallez 7f0f989
improve
Cyrilvallez 3ed17a2
simplify
Cyrilvallez f23236d
Update masking_utils.py
Cyrilvallez 0ffff1d
typos
Cyrilvallez 09d32df
typo
Cyrilvallez e20ebab
fix
Cyrilvallez d273325
Update masking_utils.py
Cyrilvallez 5ae049c
default DynamicCache
Cyrilvallez 326bacf
remove default cache
Cyrilvallez d58eaab
simplify
Cyrilvallez 02a9180
Update masking_utils.py
Cyrilvallez d67de19
Update masking_utils.py
Cyrilvallez 3831ccc
Update masking_utils.py
Cyrilvallez 4b54f18
Update masking_utils.py
Cyrilvallez 6edf116
simplify
Cyrilvallez 18614a5
Update masking_utils.py
Cyrilvallez bd931a0
Update masking_utils.py
Cyrilvallez 93f8d82
Update masking_utils.py
Cyrilvallez 711ab9b
export
Cyrilvallez 58f198e
Update executorch.py
Cyrilvallez 9c69ae5
Update executorch.py
Cyrilvallez 4e40516
Update flex_attention.py
Cyrilvallez 6a28a34
Update executorch.py
Cyrilvallez c70bf3c
upstream to modular gemma 1 & 2
Cyrilvallez 3a972d4
Update modular_mistral.py
Cyrilvallez 7ca132d
switch names
Cyrilvallez 34a55c5
use dict
Cyrilvallez 5c89d72
put it in the Layer directly
Cyrilvallez e6891b6
update copy model source for mask functions
Cyrilvallez ac02170
apply so many modular (hopefully 1 shot)
Cyrilvallez 59e11ab
use explicite dicts for make style happy
Cyrilvallez 27041e0
protect import
Cyrilvallez 0cf18e2
check docstring
Cyrilvallez 47158df
better default in hybrid caches
Cyrilvallez 022c4a9
qwens
Cyrilvallez 94896dc
Update modular_qwen2.py
Cyrilvallez 9bbe1cb
simplify core logic!
Cyrilvallez 0844a49
Update executorch.py
Cyrilvallez dbbecde
qwen3 moe
Cyrilvallez a350263
Update masking_utils.py
Cyrilvallez 09b0148
Update masking_utils.py
Cyrilvallez fcd21a4
simplify a lot sdpa causal skip
Cyrilvallez 8cb637f
Update masking_utils.py
Cyrilvallez 481f086
post-rebase
Cyrilvallez 91c87f8
gemma3 finally
Cyrilvallez 9bda864
style
Cyrilvallez d24309f
check it before
Cyrilvallez 8e153a1
gemma3
Cyrilvallez ebc7f9d
More general with newer torch
Cyrilvallez 31008ba
align gemma3
Cyrilvallez 3c385ea
Update utils.py
Cyrilvallez b206cd5
Update utils.py
Cyrilvallez b0850bf
Update masking_utils.py
Cyrilvallez 79eac77
Update test_modeling_common.py
Cyrilvallez 29a6bc2
Update flex_attention.py
Cyrilvallez bb2dda0
Update flex_attention.py
Cyrilvallez 1b85bbb
Update flex_attention.py
Cyrilvallez f76df19
test
Cyrilvallez 3ff3908
executorch
Cyrilvallez fd8a6a2
Update test_modeling_common.py
Cyrilvallez 84db8ee
Update masking_utils.py
Cyrilvallez 83ba79f
Update masking_utils.py
Cyrilvallez acbe4be
Update masking_utils.py
Cyrilvallez b0333de
Update masking_utils.py
Cyrilvallez 3c48334
Update executorch.py
Cyrilvallez cfd0694
Update test_modeling_common.py
Cyrilvallez 0181042
fix copies
Cyrilvallez ad5fb36
device
Cyrilvallez b477c1e
sdpa can be used without mask -> pass the torchscript tests in this case
Cyrilvallez 3b71b7b
Use enum for check
Cyrilvallez 1a05ca1
revert enum and add check instead
Cyrilvallez 2029cfa
remove broken test
Cyrilvallez 28d62da
cohere2
Cyrilvallez 9d7bd3a
some doc & reorganize the Interface
Cyrilvallez 343ab95
Update tensor_parallel.py
Cyrilvallez 78a21ea
Update tensor_parallel.py
Cyrilvallez 4c87caa
doc and dummy
Cyrilvallez 1f21213
Update test_modeling_paligemma2.py
Cyrilvallez e353067
Update modeling_falcon_h1.py
Cyrilvallez 7979ac6
Update masking_utils.py
Cyrilvallez ba6501c
executorch patch
Cyrilvallez 269969e
style
Cyrilvallez 75ccf7a
CIs
Cyrilvallez 7bcd55f
use register in executorch
Cyrilvallez 9245fcd
final comments!
Cyrilvallez File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's rather show how to do something like the paligemma or document masking here, something relevant!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those are a bit different, it's modifying the mask pattern vs adding a new mask format for the attention itself (both are complementary)