Fix `key_padding_mask` when using `attn_impl='flash'` #163

Conversation
if throughput looks good, then this lgtm
Verified with a 7B model on 80GB and 40GB cards, and throughput looks good. Adding a comment with a link to this PR and then merging.
When we merged #128, we removed the codepath by which we always force a padding token (`pad_token = eos_token`) for tokenizers like GPT2, since that can lead to silent errors. We also started pretokenizing our data and passing raw torch tensors to the HF collate fn, rather than a batch created by a HF tokenizer (with `attention_mask` optionally inside).

These changes both lead to our text dataloader batches no longer producing an `attention_mask` with the GPT2 tokenizer; it just produces `None`, since no padding is possible anymore.
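Concretely, a minimal sketch (with illustrative shapes and names, not the exact dataloader code) of what the batches now look like:

```python
import torch

# Illustrative sketch, not the repo's dataloader: pretokenized, fixed-length
# sequences mean the collate fn never emits an attention_mask.
batch = {"input_ids": torch.randint(0, 50257, (8, 2048))}

attention_mask = batch.get("attention_mask")  # None: nothing is padded
key_padding_mask = attention_mask             # so the model sees key_padding_mask=None
```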
This meant that our `MosaicGPT` class was now receiving `key_padding_mask=None` in all situations, including `attn_impl='flash'`. We had not tested this case before because it had never come up.

Upon initial testing of `attn_impl='flash'` up to 3B models, nothing seemed broken: training throughput was fine, and even slightly higher for some models.

Last night, while testing 7B models with the settings in our `throughput/` tables, we noticed that the MFU was significantly worse and there was a lot of "thrashing" in the memory allocator.
The `FlashMHA` layer from HazyResearch has a separate codepath when `key_padding_mask=None`: https://github.com/HazyResearch/flash-attention/blob/2dc2a195890d323f6f9e1b74e4667099e6144f79/flash_attn/flash_attention.py#L42

For reasons unknown, this codepath hurts performance for the larger models. Avoiding it by sending a `key_padding_mask` of all 1s appears to fix the problem.
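A minimal sketch of that workaround (the helper name is hypothetical, not the exact code in this PR): when the batch carries no attention mask, substitute an all-ones `key_padding_mask` so `FlashMHA` never takes its `key_padding_mask=None` codepath.

```python
from typing import Optional

import torch


def build_key_padding_mask(input_ids: torch.Tensor,
                           attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: never hand FlashMHA a None mask."""
    if attention_mask is not None:
        # Reuse the dataloader's mask when one exists.
        return attention_mask.bool()
    # No padding anywhere, so every position is a valid key: an all-ones mask
    # is equivalent to "no padding" but keeps FlashMHA on its masked codepath.
    return torch.ones_like(input_ids, dtype=torch.bool)
```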