
Conversation

@kocchop (Collaborator) commented Jan 1, 2025

Description

This PR adds Context Parallelism support to GPU Flash Attention, which is necessary to support large sequence lengths in MaxText. The support is currently offered through Transformer-Engine and uses an All-Gather style implementation. Note that it requires the mask type to be causal and does not yet work with sliding window attention. It also requires transformer-engine==1.13 or above.
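
For background, here is a minimal JAX sketch of what an all-gather style context-parallel attention amounts to: each rank keeps its local query chunk, gathers the full K/V, and applies a causal mask built from global positions. The axis name "cp", the plain softmax attention, and the position handling are illustrative assumptions, not the Transformer-Engine kernel or this PR's exact code.

```python
# Minimal sketch of all-gather style context parallelism (NOT the TE kernel):
# each device keeps its local query chunk, all-gathers K/V over the "cp" axis,
# and applies a causal mask built from global token positions.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def cp_causal_attention(q, k, v, mesh):
  # q, k, v: [seq, head_dim]; mesh must have a "cp" axis over which seq is sharded.
  def body(q_loc, k_loc, v_loc):
    cp = jax.lax.psum(1, "cp")                       # number of CP ranks
    rank = jax.lax.axis_index("cp")
    k_all = jax.lax.all_gather(k_loc, "cp", axis=0, tiled=True)
    v_all = jax.lax.all_gather(v_loc, "cp", axis=0, tiled=True)
    q_pos = rank * q_loc.shape[0] + jnp.arange(q_loc.shape[0])  # global query positions
    k_pos = jnp.arange(k_all.shape[0])
    scores = q_loc @ k_all.T / jnp.sqrt(q_loc.shape[-1])
    scores = jnp.where(q_pos[:, None] >= k_pos[None, :], scores, -jnp.inf)  # causal mask
    return jax.nn.softmax(scores, axis=-1) @ v_all
  return shard_map(body, mesh=mesh,
                   in_specs=(P("cp", None),) * 3,
                   out_specs=P("cp", None))(q, k, v)
```

With load balancing enabled, each rank's chunks are no longer contiguous, so the global query positions would be derived from the re-ordered layout (see the sketch under the change list below).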

NEW

  1. Implemented input sequence re-ordering at the data loading stage (see the load-balancing sketch after this list)
  2. Added a context_parallel_load_balance flag in base.yml to turn load balancing on/off
  3. Added a sequence packing flag enable_packing in base.yml and modified the associated data processing files
  4. Added/modified unit tests for context parallelism and the GPU flash attention kernel cudnn_flash_te
  5. Added sharding of parameters across CP ranks
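
As a rough illustration of items 1 and 2, a common load-balancing scheme for causal context parallelism splits each sequence into 2*cp chunks and pairs an early chunk with a late one per rank, so every rank gets a similar amount of causal work. The function below is a sketch under that assumption; the exact pairing and flag semantics in this PR may differ.

```python
import numpy as np

def reorder_for_cp_load_balance(tokens: np.ndarray, cp: int) -> np.ndarray:
  """Re-order [batch, seq] tokens so rank r later receives chunks (r, 2*cp-1-r).

  Assumes seq is divisible by 2 * cp. Illustrative only.
  """
  batch, seq = tokens.shape
  chunks = tokens.reshape(batch, 2 * cp, seq // (2 * cp))
  order = []
  for r in range(cp):
    order += [r, 2 * cp - 1 - r]   # pair an "early" and a "late" chunk per rank
  return chunks[:, order, :].reshape(batch, seq)

# Example: seq=8, cp=2 -> chunk order [0, 3, 1, 2],
# so tokens [0..7] become [0, 1, 6, 7, 2, 3, 4, 5].
```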

Tests

A unit test is included with the PR, using the base model on 4 x A100 GPUs.

Checklist

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@gobbleturk (Collaborator) left a comment

Awesome, this looks great! Just one change to the sharding rules (with the current state of this PR, it looks like the old sequence parallelism is broken).

I'd appreciate it if you could also run bash code_style.sh to run our linter (I recommend saving the branch before running this just in case...)

@A9isha (Collaborator) left a comment

This is great! Thank you so much.

Curious: what kind of improvement are you observing with load balancing enabled?

@A9isha (Collaborator) commented Mar 7, 2025

Oh, also please merge all commits into one before pushing.

@kocchop (Collaborator, Author) commented Mar 11, 2025

> This is great! Thank you so much.
>
> Curious: what kind of improvement are you observing with load balancing enabled?

For Llama3-8B, with load balancing and cp=2, we see around a 10% perf improvement.

@kocchop kocchop requested a review from gobbleturk March 22, 2025 00:58
@kocchop (Collaborator, Author) commented Mar 22, 2025

@gobbleturk @A9isha once I get an LGTM from you, I can work on resolving the merge conflicts and also merge all the commits into one.

@gobbleturk (Collaborator) left a comment

LGTM! I would like either @khatwanimohit or @aireenmei to review the data loading changes for enable_packing=True.

@A9isha (Collaborator) left a comment

LGTM - thank you!

@kocchop kocchop requested a review from aireenmei April 11, 2025 02:06
@kocchop kocchop force-pushed the faysal/add-cp-to-cudnn-flash-te branch from a3e1739 to 5016d35 on April 11, 2025 03:12
@kocchop (Collaborator, Author) commented Apr 11, 2025

Hi @gobbleturk, could you please approve so that the unit tests can run?

@kocchop kocchop force-pushed the faysal/add-cp-to-cudnn-flash-te branch from 5016d35 to dbf3d70 on April 11, 2025 21:27
1. Implemented input sequence re-ordering at the data loading stage
2. Added context_parallel_load_balance flag in base.yml to turn load balancing on/off
3. Added sequence packing flag packing in base.yml and also modified the associated data processing files
4. Added/modified unit tests for context parallelism and GPU flash attention cudnn_flash_te
5. Added sharding of parameters across CP ranks (modified the logical axis sharding rules; see the sketch below)
6. Fixed the q, k, v, and out_proj sharding axis names by adding MODEL_MODE_TRAIN in attention.py
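
For item 5, here is a minimal JAX-level sketch of what sharding along a context-parallel mesh axis looks like; the 2x2 mesh (assuming 4 devices) and the axis names "data"/"context" are illustrative, not MaxText's actual logical axis rules.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2x2 mesh over 4 GPUs: 2-way data parallel x 2-way context parallel.
devices = np.array(jax.devices()).reshape(2, 2)
mesh = Mesh(devices, axis_names=("data", "context"))

# Activations [batch, seq, hidden]: batch sharded over "data", sequence over "context".
x = jnp.zeros((8, 4096, 1024))
x = jax.device_put(x, NamedSharding(mesh, P("data", "context", None)))
jax.debug.visualize_array_sharding(x[:, :, 0])   # shows the [batch, seq] layout
```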
@kocchop kocchop force-pushed the faysal/add-cp-to-cudnn-flash-te branch from dbf3d70 to 336fb38 on April 11, 2025 23:09
@kocchop kocchop requested a review from gobbleturk April 11, 2025 23:09
@A9isha (Collaborator) left a comment

Thank you!

@aireenmei (Collaborator) left a comment

LGTM for data input changes!

@gobbleturk (Collaborator)
All tests passed, manually adding pull ready

@gobbleturk (Collaborator) commented Apr 12, 2025

The pull ready label might have been added automatically after an additional approval step from me to run the next steps; not sure.

@copybara-service copybara-service bot merged commit c4f7060 into AI-Hypercomputer:main Apr 12, 2025
18 of 21 checks passed
@A9isha A9isha mentioned this pull request Apr 17, 2025