
Conversation

Collaborator

@shuningjin shuningjin commented Jul 25, 2025

Description

Goal

For mixture-of-experts models, we may use expert parallelism (EP). In the attention layer, EP currently acts as FSDP. Building on the previous context parallelism (CP) work, this PR introduces the option of using EP as CP for attention. This is a joint effort with @RissyRan.

FIXES: b/418396648

Main code changes

  • attentions.py

    • MHA, MLA, tpu_flash_attention
    • changed the logical axis annotations
  • base.yml

    • changed logical_axis_rules
    • added a flag, expert_shard_attention_option (fsdp or context), to customize the expert sharding behavior in attention (see the sketch after this list)
  • unit test: tests.attention_test

    • AttentionTest.test_tpu_flash_attention_cp_and_ep and MLATest.test_tpu_flash_attention_cp_and_ep (extended from the CP tests)
    • run CP/EP with tpu_flash_attention, with or without context_parallel_load_balance
    • compare logits against dot_product attention without sharding
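
The flag essentially toggles which role the expert mesh axis plays inside attention. A minimal Python sketch of that idea is below; the rule tuples and axis names are illustrative assumptions, not the actual MaxText logical_axis_rules.

```python
# Illustrative sketch only: how a flag like expert_shard_attention_option could
# select the mesh axes used inside attention. Rule and axis names are assumed.

def attention_rules(expert_shard_attention_option: str):
  """Return hypothetical (logical_axis, mesh_axes) rules for attention activations."""
  if expert_shard_attention_option == "context":
    # EP acts as CP: the expert mesh axis helps shard the sequence dimension.
    return [("activation_length", ("context", "expert"))]
  if expert_shard_attention_option == "fsdp":
    # Default behavior: EP acts as FSDP, so the expert axis shards the batch dimension.
    return [("activation_batch", ("data", "fsdp", "expert"))]
  raise ValueError(f"unknown expert_shard_attention_option: {expert_shard_attention_option!r}")
```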

Use case

  • training, MoE with MHA / MLA attention
  • context_parallel_load_balance={true, false}, tpu_flash_attention
  • Example (ep_as_cp): ici_expert_parallelism=4, ici_context_parallelism=1, expert_shard_attention_option=context; shard context by 4 in attention, shard expert by 4 for MoE
  • Example (ep_as_cp + native cp): ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=context; shard context by 4 in attention, shard expert by 2 and context by 2 for MoE (the sketch below checks this arithmetic)
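
The arithmetic in the two examples can be sanity-checked with a tiny sketch. The helper name is hypothetical; it only encodes the rule that EP joins CP inside attention when expert_shard_attention_option=context.

```python
def attention_context_shards(ici_expert_parallelism: int,
                             ici_context_parallelism: int,
                             expert_shard_attention_option: str) -> int:
  """Effective sharding degree of the sequence (context) dim inside attention."""
  if expert_shard_attention_option == "context":
    # EP joins CP for attention, so the context dim is split by both.
    return ici_expert_parallelism * ici_context_parallelism
  # fsdp option: only native CP shards the context dim inside attention.
  return ici_context_parallelism

assert attention_context_shards(4, 1, "context") == 4  # ep_as_cp example
assert attention_context_shards(2, 2, "context") == 4  # ep_as_cp + native cp example
```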

Tests

Tested on v5p-8

Verify sharding shape

  • end-to-end pretraining: reduced version of Mixtral-8x7b (for MHA) and reduced DeepSeek3-671b (for MLA)
  • tpu_flash_kernel, context_parallel_load_balance=True; parallelism: {ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=context} and {ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=fsdp} (a generic sharding-inspection sketch follows this list)
  • See test details in b/418396648#comment7
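
Outside the PR's own harness, one generic way to spot-check sharding shapes in JAX is to build a small mesh with expert and context axes and inspect the per-device shards. The sketch below assumes the ep=2, cp=2 case and illustrative tensor shapes; it is not the MaxText test code.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 2x2 mesh with expert and context axes (uses the first 4 devices).
devices = mesh_utils.create_device_mesh((2, 2), devices=jax.devices()[:4])
mesh = Mesh(devices, axis_names=("expert", "context"))

# A [batch, seq, heads, head_dim] activation with the sequence dimension split
# over both axes, mimicking expert_shard_attention_option=context with ep=2, cp=2.
x = jax.device_put(
    jnp.zeros((4, 1024, 8, 128)),
    NamedSharding(mesh, P(None, ("expert", "context"), None, None)),
)

print(x.sharding)                          # the NamedSharding above
print(x.addressable_shards[0].data.shape)  # per-device shard: (4, 256, 8, 128)
```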

Verify attention output logits against dot product

  • use the added unit test
  • tpu_flash_kernel, attention {MHA, MLA}, context_parallel_load_balance={True, False}; parallelism: {ici_expert_parallelism=4, expert_shard_attention_option=context} and {ici_expert_parallelism=2, expert_shard_attention_option=context, ici_context_parallelism=2} (a generic comparison sketch follows)
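
The actual check lives in tests.attention_test; the sketch below only illustrates the general pattern of comparing an unsharded dot-product reference against the same computation run with the sequence dimension sharded over expert and context axes. Shapes, mesh layout, and tolerance are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def dot_product_attention(q, k, v):
  # Plain softmax(QK^T / sqrt(d)) V reference; no sharding assumptions.
  d = q.shape[-1]
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d)
  return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

mesh = Mesh(mesh_utils.create_device_mesh((2, 2), devices=jax.devices()[:4]),
            axis_names=("expert", "context"))
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 512, 8, 128)  # [batch, seq, heads, head_dim]
q, k, v = (jax.random.normal(key, shape) for key in (kq, kk, kv))

reference = dot_product_attention(q, k, v)  # unsharded baseline

# Same math with the sequence dimension sharded over expert+context; the
# compiler inserts the collectives, so the output should match the baseline.
seq_sharding = NamedSharding(mesh, P(None, ("expert", "context"), None, None))
out = jax.jit(dot_product_attention)(
    *(jax.device_put(t, seq_sharding) for t in (q, k, v)))

diff = np.max(np.abs(np.asarray(out) - np.asarray(reference)))
print(diff)           # small numerical difference from differing reduction order
assert diff < 1e-3
```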

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@RissyRan RissyRan left a comment


Thanks Shuning! Great work!

Collaborator

@gobbleturk gobbleturk left a comment


Have you considered an approach like conditionally modifying the rules (instead of creating new ones)? This is the approach used for pipeline parallelism:

def modify_activation_embed_and_logits_batch(logical_axis_rules):

There are pros and cons of both (both pretty ugly IMO), but at least when modifying rules there are:

  1. fewer rules
  2. no if statements to decide which rules to use
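
A rough sketch of what conditionally modifying the rules could look like (the helper name and the "length"-matching heuristic are hypothetical, not the pipeline-parallelism code referenced above):

```python
def maybe_shard_context_with_expert(logical_axis_rules, expert_shard_attention_option):
  """Hypothetical: append the expert mesh axis to sequence-length rules in place,
  instead of defining a separate set of rules for the EP-as-CP case."""
  if expert_shard_attention_option != "context":
    return logical_axis_rules
  new_rules = []
  for logical_axis, mesh_axes in logical_axis_rules:
    axes = tuple(mesh_axes) if isinstance(mesh_axes, (list, tuple)) else (mesh_axes,)
    if "length" in logical_axis and "expert" not in axes:
      axes = axes + ("expert",)
    new_rules.append((logical_axis, axes))
  return new_rules
```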

Collaborator

@richjames0 richjames0 left a comment


Really impressive that you understood this and got it working!

@shuningjin shuningjin requested a review from NuojCheng as a code owner August 5, 2025 22:01
@shuningjin shuningjin force-pushed the shuningjin-ep branch 5 times, most recently from a9f3ea9 to 40f8ad2 on August 7, 2025 21:16
Collaborator Author

shuningjin commented Aug 8, 2025

Resolved merge conflict with the nnx migration for the attention layer.

Re-tested on a local v5p-8; diff (before vs. after the nnx migration):

  • Mini Mixtral: 1.1, 1.2. Mini DeepSeek3: 2.1, 2.2
  • Sharding shapes are still correct, and training loss is close.
  • TFLOP/s/device is better now. In both cases, the FLOPs calculation includes the recent change. Possible reasons: different docker image (jax 0.6.2 vs. 0.7.0), linen vs. nnx.

@copybara-service copybara-service bot merged commit b98c47a into main Aug 8, 2025
18 checks passed
@copybara-service copybara-service bot deleted the shuningjin-ep branch August 8, 2025 05:25
@shuningjin shuningjin mentioned this pull request Aug 29, 2025
@gobbleturk gobbleturk mentioned this pull request Sep 8, 2025
