
Fix key_padding_mask when using attn_impl='flash' #163

Merged — 5 commits merged into release/v0.0.3 from abhi/debug_key_padding_mask on Feb 14, 2023

Conversation

@abhi-mosaic (Contributor) commented on Feb 14, 2023

When we merged #128, we removed the codepath that always forced a padding token (pad_token = eos_token) for tokenizers like GPT2, since that can lead to silent errors. We also started pretokenizing our data and passing raw torch tensors to the HF collate fn, rather than a batch created by an HF tokenizer (which optionally includes attention_mask).

Together, these changes mean our text dataloader batches no longer produce an attention_mask with the GPT2 tokenizer; they just produce None, since padding is no longer possible.

This meant that our MosaicGPT class was now receiving key_padding_mask=None in all situations, including attn_impl='flash'. We had not tested this case before because it had never come up.

In initial testing of attn_impl='flash' with models up to 3B parameters, nothing seemed broken: training throughput was fine, and even slightly higher for some models.

Last night, while testing 7B models with the settings in our throughput/ tables, we noticed that the MFU was significantly worse, and there was a lot of "thrashing" in the memory allocator.

The FlashMHA layer from HazyResearch has a separate codepath when key_padding_mask=None: https://github.com/HazyResearch/flash-attention/blob/2dc2a195890d323f6f9e1b74e4667099e6144f79/flash_attn/flash_attention.py#L42

For reasons unknown, this codepath is hurting performance for the larger models. Avoiding this codepath by sending a key_padding_mask of all 1s appears to fix the problem.
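For illustration, here is a minimal sketch of the workaround (the helper name and argument list are assumptions for this example, not the exact diff in this PR): when the dataloader provides no attention_mask, substitute an all-ones key_padding_mask before calling the attention layer so FlashMHA always takes its padded codepath.

```python
import torch

def ensure_key_padding_mask(key_padding_mask, batch_size, seq_len, device):
    # Hypothetical helper: if the dataloader gave us no attention_mask,
    # treat every position as a real (non-padding) token by passing an
    # all-ones mask instead of None, so FlashMHA avoids its mask=None path.
    if key_padding_mask is None:
        key_padding_mask = torch.ones(
            batch_size, seq_len, dtype=torch.bool, device=device
        )
    return key_padding_mask
```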

Resolved review threads on examples/llm/src/mosaic_gpt.py (outdated)
@dakinggg (Collaborator) left a comment

This looks good to me now, but I'm going to hold off on approving because I'd like another pair of eyes since it's a late-breaking fix. cc @dskhudia or @vchiley for another look

@vchiley (Contributor) left a comment

if throughput looks good, then this lgtm

@abhi-mosaic (Contributor, Author) commented

Verified with a 7B model on 80GB and 40GB cards, and throughput looks good. Adding a code comment with a link to this PR, then merging.

@abhi-mosaic abhi-mosaic merged commit 61a00d6 into release/v0.0.3 Feb 14, 2023
@abhi-mosaic abhi-mosaic deleted the abhi/debug_key_padding_mask branch February 14, 2023 19:20