
custom 4d attention masks broken by #28937 · Issue #29525

Closed
poedator opened this issue Mar 7, 2024 · 3 comments · Fixed by #29618 or #29731

Comments

poedator (Contributor) commented Mar 7, 2024

System Info

Version 4.38.2 breaks code that uses custom 4D attention masks (introduced in #27539). Apparently, the custom mask gets replaced here:

```python
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
    causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
```
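For context, a minimal sketch of the kind of input that regresses; the shapes and the 1-visible / 0-masked convention here are illustrative assumptions, not the exact API contract:

```python
import torch

# Hypothetical custom 4D mask of shape (batch, num_heads_or_1, q_len, kv_len),
# e.g. encoding a tree of speculative-decoding candidates.
batch, q_len, kv_len = 1, 6, 6
mask_4d = torch.tril(torch.ones(batch, 1, q_len, kv_len))
mask_4d[0, 0, 4, 1] = 0.0  # custom structure: query 4 must not see key 1

# Passed as model(input_ids, attention_mask=mask_4d), the snippet quoted
# above then re-indexes the query dimension of this mask by cache_position.
```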

The issue was introduced with #28937. It is unclear whether the relevant slow tests for 4D masks were run then, but they fail now:

```
RUN_SLOW=1 python -m pytest -v ./tests/test_modeling_utils.py::Mask4DTestFP32
FAILED tests/test_modeling_utils.py::Mask4DTestFP32::test_attention - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/test_modeling_utils.py::Mask4DTestFP32::test_causal_model_logits - AssertionError: Tensor-likes are not close!
FAILED tests/test_modeling_utils.py::Mask4DTestFP32::test_inner_model - AssertionError: Tensor-likes are not close!

RUN_SLOW=1 python -m pytest -v ./tests/test_modeling_utils.py::Mask4DTestFP16
FAILED tests/test_modeling_utils.py::Mask4DTestFP16::test_attention - AttributeError: 'NoneType' object has no attribute 'shape'
FAILED tests/test_modeling_utils.py::Mask4DTestFP16::test_causal_model_logits - AssertionError: Tensor-likes are not close!
```

Please fix or suggest a workaround.

summoning @ArthurZucker
cc @gante @younesbelkada

gante (Member) commented Mar 12, 2024

@poedator thank you for opening this issue! The PR linked above should fix it 🙏

poedator (Contributor, Author) commented

Hi @gante,
Sorry, but it is not fixed yet.

I tested with yesterday's commit 56b64bf, and the problem still persists.
While debugging the issue, I found the following:

  • A custom 4D mask gets pasted into the top corner of causal_mask in _update_causal_mask() here.
  • The first decoding iteration (with an empty cache) runs smoothly, so the tests pass.
  • At the next decoding iteration, cache_position is set to torch.arange() starting from a non-zero past_seen_tokens here.
  • Later, in the decoder layer, the cache_position tensor is used to extract the attention mask from the causal mask. But since past_seen_tokens is non-zero, the extracted mask is DIFFERENT FROM THE ORIGINAL MASK (see the sketch below).
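A minimal numeric sketch of that second iteration, under my reading of the code (the buffer shape and the min-value fill are illustrative assumptions):

```python
import torch

# Second decoding step: 2 new tokens, 4 tokens already in the kv cache.
q_len, past_seen_tokens = 2, 4
kv_len = past_seen_tokens + q_len
cache_position = torch.arange(past_seen_tokens, kv_len)  # tensor([4, 5])

# User-supplied 4D mask for the 2 current queries over all 6 key positions.
custom_mask = torch.ones(1, 1, q_len, kv_len)
custom_mask[0, 0, 1, 2] = 0.0  # custom structure

# _update_causal_mask() pastes it into the top corner of a larger buffer...
min_value = torch.finfo(torch.float32).min
causal_mask = torch.full((1, 1, kv_len, kv_len), min_value)
causal_mask[:, :, :q_len, :kv_len] = custom_mask  # rows 0..1 hold the user mask

# ...but the decoder layer then re-slices the rows by cache_position
# (rows 4..5), so the extracted mask no longer contains the custom rows.
extracted = causal_mask[:, :, cache_position, :kv_len]
assert not torch.equal(extracted, custom_mask)
```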

Proposed solutions:
a) ignore cache_position when handling a custom 4D attention mask;
b) document how to construct cache_position when passing a custom 4D attention mask, so that it gets delivered intact (illustrated below);
c) ensure delivery of the 4D mask intact some other way, maybe with an extra option.
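As a toy illustration of option (b), continuing the sketch above: the slice returns the user mask intact only if cache_position points at the rows where it was pasted (again assuming the paste lands in the top rows; this is not a documented recipe):

```python
# Continuing the sketch above: aiming cache_position at the pasted rows
# would recover the custom mask unchanged.
cache_position = torch.arange(0, q_len)  # tensor([0, 1])
extracted = causal_mask[:, :, cache_position, :kv_len]
assert torch.equal(extracted, custom_mask)
```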

Also, the test should be updated to perform more than one forward iteration, to allow testing with the cache (a rough sketch follows below). I am not attempting a PR here because I don't know the greater context of the changes in this part of transformers. I will be glad to test, though.
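A rough sketch of such a test; the helper structure, the prefill/decode split, and the tolerances are made up for illustration, and whether cache_position must also be passed explicitly depends on the eventual fix:

```python
import torch

def check_4d_mask_with_cache(model, input_ids, mask_4d):
    """Compare a cached two-step forward pass against a single full pass.

    mask_4d has shape (batch, 1, seq_len, seq_len) for the full sequence.
    """
    split = input_ids.shape[1] - 1  # prefill all but the last token

    # Step 1: prefill with the top-left block of the custom mask.
    out1 = model(
        input_ids[:, :split],
        attention_mask=mask_4d[:, :, :split, :split],
        use_cache=True,
    )
    # Step 2: decode the last token with its row of the custom mask.
    out2 = model(
        input_ids[:, split:],
        attention_mask=mask_4d[:, :, split:, :],
        past_key_values=out1.past_key_values,
        use_cache=True,
    )
    # Reference: one full forward pass with the complete 4D mask.
    ref = model(input_ids, attention_mask=mask_4d)
    torch.testing.assert_close(out2.logits, ref.logits[:, split:], rtol=1e-4, atol=1e-4)
```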

Please fix this - I need it to work for my fancy speculative decoding trees (will show you soon).

poedator (Contributor, Author) commented

I put together a notebook with tests.
https://gist.github.com/poedator/5c2faaff175aa8a4f12671a8e0ce835c

The test that uses the kv cache fails with 4.39.dev but works fine with 4.37.2. I also wrapped it as a new test case, ready to be pasted into transformers/tests/test_modeling_utils.
