
Conversation

@gante (Member) commented Mar 28, 2024

What does this PR do?

Fixes the issue raised by @poedator in this comment.

The causal mask is now of shape [..., seq_len, full_len], as opposed to [..., full_len, full_len]. This means a custom 4D attention mask is now the whole causal mask, so we don't need a sliced copy -- we can copy the whole thing :)

This PR also expands support for custom 4D attention masks: we can pass either the full mask ([..., full_len, full_len]) or the partial mask ([..., seq_len, full_len]).
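To make the two accepted shapes concrete, here is a rough sketch (mine, not code from this PR; the inverted 0/min float mask convention and the sizes are assumptions):

```python
import torch

# Decoding 2 new tokens with 6 already in the cache: seq_len=2, full_len=8.
batch, heads, seq_len, full_len = 1, 1, 2, 8
min_value = torch.finfo(torch.float32).min  # masked positions get a large negative value

# Partial custom mask, [..., seq_len, full_len]: one row per token currently being processed.
partial = torch.full((batch, heads, seq_len, full_len), min_value)
partial[..., 0, :7] = 0.0  # first new token sees the 6 cached tokens + itself
partial[..., 1, :8] = 0.0  # second new token sees everything so far

# Full custom mask, [..., full_len, full_len]: one row per position of the whole sequence;
# after this PR the model slices out the rows for the current tokens internally.
visible = torch.tril(torch.ones(full_len, full_len, dtype=torch.bool))
full = torch.full((batch, heads, full_len, full_len), min_value).masked_fill(visible, 0.0)

print(partial.shape, full.shape)  # torch.Size([1, 1, 2, 8]) torch.Size([1, 1, 8, 8])
```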

@gante requested a review from ArthurZucker March 28, 2024 10:35
if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    if attention_mask.dim() == 2:
if attention_mask is not None and attention_mask.dim() == 4:
@gante (Member Author):

Reordered the logic: custom 4D masks are now a superset of the default mask, so we don't need to create the default mask first :)

    offset = cache_position[0]
    mask_slice = mask_slice[..., offset : offset + sequence_length, :]
    causal_mask = mask_slice
else:
@gante (Member Author):

This else branch has no changes; only the if attention_mask is not None and attention_mask.dim() == 4: branch is different.
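A standalone sketch of what the slicing above does (the helper name and dummy values are mine, not from the PR):

```python
import torch

def slice_custom_4d_mask(mask, cache_position, sequence_length):
    # If the user passed a full [..., full_len, full_len] mask, keep only the rows of the
    # tokens currently being processed, producing the partial [..., seq_len, full_len] shape.
    if mask.shape[-2] == sequence_length:
        return mask  # already a partial mask, nothing to slice
    offset = cache_position[0]
    return mask[..., offset : offset + sequence_length, :]

full_len, seq_len = 8, 2
full_mask = torch.zeros(1, 1, full_len, full_len)   # dummy inverted-form mask (all visible)
cache_position = torch.arange(6, 6 + seq_len)       # currently decoding tokens 6 and 7
print(slice_custom_4d_mask(full_mask, cache_position, seq_len).shape)  # torch.Size([1, 1, 2, 8])
```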

        self.assertEqual(decoded_0, decoded_1b)

        # Case 2: we pass a 4D attention mask regarding the full sequence length (i.e. [..., full_len, full_len])
@gante (Member Author):

Added this test case (we can now pass full custom 4D attention masks).
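Roughly, the new case checks that an explicit full-length causal 4D mask reproduces the default causal run. A self-contained sketch under assumptions (tiny test checkpoint, inverted 0/min float mask convention; the real test uses the repo's own fixtures and compares generated text):

```python
import torch
from transformers import AutoModelForCausalLM

# Assumption: a tiny public test checkpoint stands in for the real test's fixtures.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
input_ids = torch.randint(0, model.config.vocab_size, (1, 5))
full_len = input_ids.shape[1]

# Explicit full-length causal mask in inverted float form (0 = visible, min = masked).
min_value = torch.finfo(model.dtype).min
visible = torch.tril(torch.ones(full_len, full_len, dtype=torch.bool))
full_mask = torch.full((1, 1, full_len, full_len), min_value).masked_fill(visible, 0.0)

with torch.no_grad():
    logits_default = model(input_ids).logits                           # implicit causal mask
    logits_custom = model(input_ids, attention_mask=full_mask).logits  # explicit full 4D mask

torch.testing.assert_close(logits_default, logits_custom)  # should match up to numerical tolerance
```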

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Thanks, LGTM! I just want to always trigger the tests.

@poedator (Contributor):

Thanks @gante!
I tested the code and it works.
My only suggestion is to cover the 4D masks in the documentation for Llama and other models.

@gante (Member Author) commented Mar 28, 2024

> My only suggestion is to cover the 4D masks in the documentation for Llama and other models.

@poedator would you like to open a PR with that? As a user, you'll probably have cool examples in mind!

@poedator (Contributor):

> My only suggestion is to cover the 4D masks in the documentation for Llama and other models.
> @poedator would you like to open a PR with that? As a user, you'll probably have cool examples in mind!

Will try, but not this week...

@gante requested a review from ArthurZucker March 28, 2024 15:54
@@ -735,3 +736,138 @@ def test_model_7b_logits(self):
        ]
        infilling = tokenizer.batch_decode(generated_ids)
        self.assertEqual(infilling, EXPECTED_INFILLING)


@slow
@gante (Member Author) commented Mar 28, 2024:

This set of slow tests was moved to the Llama test file -> since we often request the slow Llama tests, these tests will now be triggered along with them.
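For reference, a minimal sketch of that gating (class body elided; the actual tests sit in the Llama test file):

```python
import unittest

from transformers.testing_utils import slow


@slow  # skipped on regular CI pushes; runs whenever the slow (e.g. Llama) test suite is requested
class Mask4DTestHard(unittest.TestCase):
    ...
```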

@@ -4027,6 +4027,101 @@ def test_flash_attn_2_from_config(self):

        self.assertFalse(fa2_correctly_converted)

    def _get_custom_4d_mask_test_data(self):
@gante (Member Author) commented Mar 28, 2024:

This set of tests is now:

  1. part of the mixin, so it runs on all push commits
  2. a fast test, using model = model_class(config) built from the test config (see the sketch below)
  3. triggered by model_class._supports_cache_class == True -- recent LLMs [llama, cohere, gemma, mistral, mixtral, starcoder2, ...] have this attribute set to True and are 4D mask-compatible. Older models are often not compatible. Over time, as we spread the cache refactor, this test will run on those classes as well 👀
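As a rough illustration of points 2 and 3 (a sketch with arbitrary small config values, not the mixin's actual code):

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Gate on `_supports_cache_class`, then build a tiny model straight from a config.
if getattr(LlamaForCausalLM, "_supports_cache_class", False):
    config = LlamaConfig(
        vocab_size=128,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config).eval()  # fast: no pretrained weights downloaded
    # ... run the custom 4D mask checks on `model` ...
else:
    print("class does not support the new cache format; the test would be skipped")
```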

@gante (Member Author) commented Mar 28, 2024

@ArthurZucker ready for a re-review (test rework) -- we now have on-push tests for all recent models + custom 4D masks :)

@poedator (Contributor) commented Mar 30, 2024

@gante,
May I suggest adding extra tests for 4D masks with StaticCache? I am concerned that StaticCache code may handle the custom masks differently. I intend to use 4D masks with StaticCache in my new speculative decoding implementation.
Here are the additional methods for Mask4DTestHard:
https://gist.github.com/poedator/f1c15551d202df2682c65c1bbdcb1c07

I made the cache longer than the masks and padded the masks to the cache length. Is this the correct way?
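A minimal sketch of the padding described above, assuming the usual inverted-float mask convention (the helper name and sizes are illustrative, not taken from the gist):

```python
import torch

def pad_4d_mask_to_cache_len(mask: torch.Tensor, max_cache_len: int) -> torch.Tensor:
    # Right-pad the key/value dimension with the masking value so the mask width matches
    # the fixed StaticCache length; the extra slots are never attended to.
    pad = max_cache_len - mask.shape[-1]
    if pad <= 0:
        return mask
    min_value = torch.finfo(mask.dtype).min
    padding = torch.full((*mask.shape[:-1], pad), min_value, dtype=mask.dtype)
    return torch.cat([mask, padding], dim=-1)

mask = torch.zeros(1, 1, 3, 6)   # [batch, heads, seq_len, full_len], all positions visible
padded = pad_4d_mask_to_cache_len(mask, max_cache_len=16)
print(padded.shape)              # torch.Size([1, 1, 3, 16])
```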

@ArthurZucker (Collaborator):

Sorry for the delay, let's rebase on main as well

@ArthurZucker (Collaborator) left a comment:

Very good! Let's rebase on main (#30047 was merged) and run the slow tests!

@poedator (Contributor):

Please, please merge this PR - I need it for my speculative decoding paper project! The 4D masks are essential for it.
@gante @ArthurZucker

@ArthurZucker (Collaborator):

Sorry, just got back to GitHub 😓 could you rebase?

@poedator (Contributor) commented Apr 19, 2024

> Sorry, just got back to GitHub 😓 could you rebase?
Too bad it did not make it into 4.40 (

I rebased this PR in a new one, #30348, and added a few important changes.

@poedator (Contributor) commented Apr 22, 2024

@gante, thank you for the rebase. Meanwhile, I added more improvements in #30348 - let's close this one (#29930) and continue there.

@gante (Member Author) commented Apr 23, 2024

Closing in favor of #30348

@gante closed this Apr 23, 2024