Llama: fix custom 4D masks #29930
Conversation
if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    if attention_mask.dim() == 2:
if attention_mask is not None and attention_mask.dim() == 4:
reordered the logic: custom 4D masks are now a superset of the default mask, so we don't need to create the default mask first :)
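To illustrate the idea, here is a minimal, self-contained sketch (made-up sizes, not the actual modeling_llama.py code): a full custom 4D mask already contains every row the model could need, so the forward pass can simply slice out the rows for the current query positions instead of building the default causal mask first.

```python
import torch

full_len, sequence_length = 8, 3                    # e.g. 5 cached tokens + 3 new queries
cache_position = torch.arange(5, 5 + sequence_length)

# Hypothetical full custom 4D mask over the whole sequence (here: plain causal, 1 = attend)
custom_4d_mask = torch.tril(torch.ones(full_len, full_len))[None, None, :, :]

# Slice out only the rows for the tokens processed in this forward pass
offset = cache_position[0]
mask_slice = custom_4d_mask[..., offset : offset + sequence_length, :]
causal_mask = mask_slice
print(causal_mask.shape)                            # torch.Size([1, 1, 3, 8])
```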
offset = cache_position[0]
mask_slice = mask_slice[..., offset : offset + sequence_length, :]
causal_mask = mask_slice
else:
This `else` has no changes. Only the `if attention_mask is not None and attention_mask.dim() == 4:` line is different.
tests/test_modeling_utils.py (outdated)
self.assertEqual(decoded_0, decoded_1b)

# Case 2: we pass a 4D attention mask regarding the full sequence length (i.e. [..., full_len, full_len])
Added this test case (we can now pass full custom 4D attention masks)
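Roughly what this case exercises, as a hedged sketch with a tiny randomly initialized Llama and invented sizes (not the actual test code): a full `[..., full_len, full_len]` causal mask, passed in the inverted/additive form the tests use (0 = attend, large negative = masked), should reproduce the plain causal forward pass.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = LlamaForCausalLM(config).eval()             # tiny random weights, CPU is fine

input_ids = torch.tensor([[3, 5, 7, 11, 13]])
full_len = input_ids.shape[1]

# Full [1, 1, full_len, full_len] causal mask in additive form (0 = attend, min = masked)
keep = torch.tril(torch.ones(full_len, full_len, dtype=torch.bool))[None, None]
mask_4d = torch.where(keep, 0.0, torch.finfo(model.dtype).min)

with torch.no_grad():
    logits_default = model(input_ids).logits
    logits_custom = model(input_ids, attention_mask=mask_4d).logits

print(torch.allclose(logits_default, logits_custom, atol=1e-4))  # expected: True (up to tiny kernel differences)
```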
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks, LGTM! I just want to always trigger the tests.
Thanks @gante!
@poedator would you like to open a PR with that? As a user, you'll probably have cool examples in mind!
Will try, but not this week...
@@ -735,3 +736,138 @@ def test_model_7b_logits(self):
        ]
        infilling = tokenizer.batch_decode(generated_ids)
        self.assertEqual(infilling, EXPECTED_INFILLING)

    @slow
This set of slow tests was moved to the Llama test file -> if we run the slow Llama tests, which we often request, these will now be triggered.
@@ -4027,6 +4027,101 @@ def test_flash_attn_2_from_config(self):

        self.assertFalse(fa2_correctly_converted)

    def _get_custom_4d_mask_test_data(self):
This set of tests is now:
- part of the mixin, so they are run on all push commits
- a fast test, using `model = model_class(config)` from the test config -- triggered by `model_class._supports_cache_class == True`. Recent LLMs (llama, cohere, gemma, mistral, mixtral, starcoder2, ...) have this attribute set to `True` and are 4D mask-compatible. Older models are often not compatible. Over time, as we spread the cache refactor, this test will be run on those classes as well 👀
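For illustration, a rough, self-contained skeleton of that pattern (this is not the actual ModelTesterMixin code; the class name, config sizes, and assertion below are invented): build a tiny model directly from a config and skip classes that don't set `_supports_cache_class`.

```python
import unittest
import torch
from transformers import LlamaConfig, LlamaForCausalLM

class Custom4DMaskSketchTest(unittest.TestCase):
    model_class = LlamaForCausalLM
    config = LlamaConfig(
        vocab_size=64, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=2, num_key_value_heads=2,
    )

    def test_custom_4d_attention_mask(self):
        # Gate on the attribute mentioned above: models that haven't gone through
        # the cache refactor are skipped rather than failed.
        if not getattr(self.model_class, "_supports_cache_class", False):
            self.skipTest("model class does not support the new cache format")

        model = self.model_class(self.config).eval()   # tiny random weights -> fast test
        input_ids = torch.tensor([[1, 2, 3, 4]])
        full_len = input_ids.shape[1]

        # Custom 4D mask in additive form (0 = attend, large negative = masked)
        keep = torch.tril(torch.ones(full_len, full_len, dtype=torch.bool))[None, None]
        mask_4d = torch.where(keep, 0.0, torch.finfo(model.dtype).min)

        with torch.no_grad():
            logits = model(input_ids, attention_mask=mask_4d).logits
        self.assertEqual(tuple(logits.shape), (1, full_len, self.config.vocab_size))

if __name__ == "__main__":
    unittest.main()
```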
@ArthurZucker ready for a re-review (test rework) -- we now have on-push tests for all recent models + custom 4D masks :)
@gante, I made the cache longer than the masks and padded the masks to the cache length. Is this the correct way?
Sorry for the delay, let's rebase on main as well
Very good! Let's rebase on main now that #30047 was merged, and run the slow tests!
Please, please merge this PR - I need it for my speculative decoding paper project! The 4D masks are essential for it.
Sorry, just got back to GitHub 😓 could you rebase?
I rebased this PR into a new one, #30348, and added a few important changes.
Closing in favor of #30348
What does this PR do?
Fixes the issue raised by @poedator in this comment.
The causal mask is now of shape `[..., seq_len, full_len]`, as opposed to `[..., full_len, full_len]`. This means a custom 4D attention mask now covers the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)
This PR also expands the support for custom 4D attention masks: we can pass either the full mask (`[..., full_len, full_len]`) or the partial mask (`[..., seq_len, full_len]`).
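As a concrete illustration of what this enables, here is a hedged, torch-only sketch in the spirit of the new tests (token values and sizes are invented): the shared-prefix / packed-candidates pattern that custom 4D masks are typically used for, built in the full `[..., full_len, full_len]` form with the inverted convention (0 = attend, large negative = masked).

```python
import torch

# Three 4-token candidates that share a 3-token prefix, packed into one sequence:
#   [10, 11, 12, 13], [10, 11, 12, 14], [10, 11, 12, 15]  ->  [10, 11, 12, 13, 14, 15]
input_ids = torch.tensor([[10, 11, 12, 13, 14, 15]])
position_ids = torch.tensor([[0, 1, 2, 3, 3, 3]])     # all three endings sit at position 3

allow = torch.tensor(
    [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0],  # token 13 sees the prefix and itself
        [1, 1, 1, 0, 1, 0],  # token 14 sees the prefix and itself, but not 13
        [1, 1, 1, 0, 0, 1],  # token 15 sees the prefix and itself, but not 13 or 14
    ],
    dtype=torch.bool,
)[None, None]                                          # shape [1, 1, full_len, full_len]

mask_4d = torch.where(allow, 0.0, torch.finfo(torch.float32).min)
# -> pass model(input_ids, attention_mask=mask_4d, position_ids=position_ids)
```

The partial form would simply be the last rows of this matrix, e.g. `mask_4d[..., -num_new_tokens:, :]` (with `num_new_tokens` being however many tokens are fed in the current step), paired with the matching `cache_position`.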