-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
4D attention_mask
support
#27539
4D attention_mask
support
#27539
Conversation
Generally, I don't have a problem with allowing to pass 4D attention masks! @poedator can you explain your use case a little bit for why you want to pass 4d attention masks? |
@patrickvonplaten
and run it with mask of all ones, passing such mask in 2D which gets expanded internally to 4D. The proposed way would be to have a batch shaped (1, 7):
At subsequent beam search iterations the mask will reflect which past tokens should the new tokens attend to. Another use case is kindly proposed by @UniverseFly below. |
Very interesting PR! Would this feature also enable SFT packing as mentioned in huggingface/trl#805? |
|
I tried this branch and the transformers/src/transformers/models/llama/modeling_llama.py Lines 1087 to 1124 in 53a7e77
|
Generate looks like a harder challenge for your methods - each individual sequence will be expanding, thus you'd need to reorder past_kv and mask at each step. I believe that to implement it, you'd need to write custom |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally the PR looks good to me! (We'd need some tests here).
@ArthurZucker wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks alright! but there should not be changes to the forward of the models (IMO)
if attention_mask is not None and len(attention_mask.shape) == 4: | ||
# assumes 4D mask for efficient beam search | ||
token_positions = torch.cumsum(attention_mask, dim=-1).amax(dim=(1, 2)) | ||
used_tokens_mask = attention_mask.amax(dim=(1, 2)) | ||
position_ids = (token_positions * used_tokens_mask).long() - 1 | ||
position_ids = position_ids[:, past_key_values_length:] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this logic should not go here, it's should go in the prepare inputs for generation, as it's purely specific to 4d beam search.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, that this should be limited to just the mask code. Makes this PR more manageable. Llama can work without that, since it can accept position_ids
argument. Hopefully the newer models will support this argument. (could HF make it a part of some model guidelines?)
Hi, @ArthurZucker So far I have demo in Colab with monkey patch based on this PR. It shows a negligible difference in logits obtained the old and new ways. I dent to believe that this is a rounding error somewhere. Would you support it as the basis for the tests? Hi, @UniverseFly , |
Thanks for this PR and the demo. It is very helpful in trying the SpecInfer paper. Also in another recent progress on speculative decoding look ahead decoding Fig 5, this PR will also be useful. |
Reviewing now 😉 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, the test should go in the
transformers/tests/test_modeling_utils.py
Line 1481 in 8eae5ea
class AttentionMaskTester(unittest.TestCase): |
@ArthurZucker, please review. Hopefully it is ready to merge. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just a few testing nits and good to go
tests/test_modeling_utils.py
Outdated
self.device = torch.device("cuda:0") | ||
model_name = "JackFram/llama-160m" # small Llama-like model from FlexFlow | ||
self.tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to(self.device) | |
self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device) |
the smaller the better for our CI
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I observed that fp16 tests are more noisy, so what I did is:
- retained fp32 testsm but used even smaller model
- added fp16 test with relaxed tolerances
- added fp16 testing option for the top tokens order.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker, pls give me a hint about |
Earlier, I got frustrated with failing commits and added decorators everywhere. Now most of them are gone and it still passes CI checks. |
Thanks for the contribution! 🤗 |
@ArthurZucker , would you want to publish a blog post in HF blog with 4d attention use cases?
|
If you want feel free to do so! 🤗 |
Note that not all paths of this can be The following fails due to
|
@PhilJd, have you tested the preceding commit? |
Ah sorry, just looked at the blame - yeah, the previous commit fails as well @fxmarty . |
|
The function description should be updated to avoid confusion as |
@shentianxiao , thank you for your attention to the 4D attention!
it is not about compatibility, rather the flash_attention_2 code contrasted original mask vs modified mask coming from
I agree, that the original mask may also be 4d-shaped now. I just started PR #28151 with documentation updates - will make edits there. Hopefully the maintainers responsible for |
IMPORTANT: this PR makes changes that can only used by few classes of models
as of 20.12.2023, only a handful (under 20) of transformers model classes meet these criteria. Most of these classes are multimodal, which may require their own use cases for 4D masks. The pure language modelling classes fit to use the 4D mask changes from this PR are only |
I made a small blog post based on this PR. |
* edits to _prepare_4d_causal_attention_mask() * initial tests for 4d mask * attention_mask_for_sdpa support * added test for inner model hidden * added autotest decorators * test mask dtype to torch.int64 * torch.testing.assert_close Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * torch_device and @torch_gpu in tests * upd tests * +torch decorators * torch decorators fixed * more decorators! * even more decorators * fewer decorators --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Thanks for the amazing addition!! This is a great new feature. Just wanted to ask a question to make sure I am using it properly. In the code here, it looks like the 4D masks are expected to have shape My question is: are the Thanks! |
@jpgard , |
Great, thanks for the quick reply and for your hard work on this @poedator !! |
Has this been tested with flash attention 2? Works great for me without flash attention 2, but when using flash attention I get lots of messages of the form Lower chunk of the stack trace posted below.
Would be great to be able to use FA2 with this PR as the speedups are much larger as sequence length grows -- so FA2 seems like the perfect accompaniment to e.g. "packed" training sequences enabled by this PR. |
@jpgard , please share some simple testing code. I will look into this issue. |
This is implementation for feature request from #27493 custom 4d attention_mask as transformers .forward() argument.
_prepare_4d_causal_attention_mask()
intactpositions
tensor)position_ids
, I added code to generate them internally)The benefits of the code are to enable more memory-efficient text generation with tree-based parallel decoding as described in SpecInfer paper
Tagging:
@gante (generate)
@patrickvonplaten (masks)
@younesbelkada @ArthurZucker (text models)
This PR is WiP:
IMPORTANT: this PR makes changes that can only used by few classes of models
requirements to use:
position_ids
argument in.forward()
methodmodeling_attn_mask_utils.py::_prepare_4d_attention_mask()
function for 4d mask generationas of 20.12.2023, only a handful (under 20) of transformers model classes meet these criteria. Most of these classes are multimodal, which may require their own use cases for 4D masks. The pure language modelling classes fit to use the 4D mask changes from this PR are only
LlamaModel
,FalconModel
andXGLMModel
.