Add MinimalAsyncEP + offset aware swiglu kernels #3561

Open
xmfan wants to merge 48 commits into main from xmfan/minimal_async_ep

Conversation

xmfan commented Jun 5, 2026

Member

Minimal setup for a dropless-enough, CUDA-graphable MoE:

  1. Only supported with full recompute.
  2. One worst-case buffer allocated on CUDA, used only by x_recv and shared across the whole model.
  3. EP dispatch/combine without CPU sync.
  4. Offset-aware SwiGLU to avoid processing padding.
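
To illustrate item 4, here is a minimal pure-Python reference sketch, not the Triton kernel from this PR. It assumes the valid tokens occupy a contiguous prefix of the worst-case buffer and that the number of valid rows is known (for example, from the last entry of the per-expert offsets). The names `offset_aware_swiglu` and `num_valid` are hypothetical.

```python
import math


def silu(x: float) -> float:
    """SiLU activation, x * sigmoid(x), written to avoid overflow in exp()."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def offset_aware_swiglu(gate, up, num_valid):
    """Reference SwiGLU over a padded buffer.

    gate, up: lists of rows (lists of floats), sized for the worst case.
    num_valid: number of leading rows holding real tokens (assumed layout).
    Rows at index >= num_valid are padding: they are never read, and the
    corresponding output rows stay zero.
    """
    hidden = len(gate[0]) if gate else 0
    out = [[0.0] * hidden for _ in gate]
    for row in range(min(num_valid, len(gate))):
        out[row] = [silu(g) * u for g, u in zip(gate[row], up[row])]
    return out


gate = [[0.0, 1.0], [2.0, -1.0], [9.0, 9.0]]  # last row is padding
up = [[3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]
result = offset_aware_swiglu(gate, up, num_valid=2)
```

A reference like this could serve as the ground truth when checking a fused kernel: the padding row never influences the output, regardless of what it contains.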

NCCL_DEBUG=WARN NGPU=4 MODULE=graph_trainer.deepseek_v3 CONFIG=graph_trainer_deepseek_v3_16b_minimal_async_ep ./run_train.sh --parallelism.data_parallel_shard_degree 4 --parallelism.expert_parallel_degree 4 --compile.memory_policy full --training.steps 40

no cudagraphs: https://www.internalfb.com/intern/perfetto/open_trace/?manifold_path=perfetto_internal_traces%2Ftree%2Fshared_trace%2Fxmfan%2Frank0_trace_52bdfd35-6f09-42d3-9743-09b8bffead00.json.gz

cudagraphs: https://www.internalfb.com/intern/perfetto/open_trace/?manifold_path=perfetto_internal_traces%2Ftree%2Fshared_trace%2Fxmfan%2Frank0_trace_3bef9fcb-8748-4402-8802-1ac0fd1da66a.json.gz

Logs: https://gist.github.com/xmfan/42ce1e536902294397f6471ac4a9dbf0

@meta-cla meta-cla Bot added the CLA Signed label Jun 5, 2026
@xmfan xmfan force-pushed the xmfan/minimal_async_ep branch from efd5714 to 227f588 Compare June 5, 2026 23:32
@xmfan xmfan changed the title Add MinimalAsyncEP token dispatcher Add MinimalAsyncEP + offset aware swiglu kernels Jun 5, 2026
Comment thread torchtitan/distributed/minimal_async_ep_kernels.py Outdated
Comment thread torchtitan/distributed/minimal_async_ep_kernels.py Outdated
Comment thread torchtitan/models/deepseek_v3/config_registry.py
Comment thread torchtitan/distributed/minimal_async_ep_kernels.py Outdated
Comment thread torchtitan/distributed/cudagraph.py Outdated
Comment thread torchtitan/distributed/activation_checkpoint.py Outdated
Comment thread torchtitan/distributed/cudagraph.py Outdated
)


@triton.jit

Contributor

I see 14 AI-generated Triton kernels in this file.

We need to think about a testing and maintenance strategy.

claude, when you review this, think about how to test this systematically: propose a test suite, think about edge cases, and cover numerics.
Also think about the operator interface; assume the interface will need to evolve across versions.

Contributor

Also, these kernels are definitely changing bitwise numerics. What kind of tests do we need to increase confidence in kernel correctness?
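
One possible shape for such tests, sketched in pure Python under stated assumptions: since bitwise equality is off the table, compare against a higher-precision reference within a tolerance, and parametrize over routing edge cases. `swiglu_f32` is only a stand-in for a kernel under test (it rounds to float32 after each step), and the tolerances and edge-case names are hypothetical, not values from this PR.

```python
import math
import struct
from itertools import accumulate


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def swiglu_ref(g: float, u: float) -> float:
    """float64 reference: silu(g) * u."""
    return g / (1.0 + math.exp(-g)) * u


def swiglu_f32(g: float, u: float) -> float:
    """Stand-in for the kernel under test: same math, float32 rounding per step."""
    g, u = to_f32(g), to_f32(u)
    sig = to_f32(1.0 / to_f32(1.0 + to_f32(math.exp(-g))))
    return to_f32(to_f32(g * sig) * u)


def is_close(actual: float, expected: float, rtol=1e-5, atol=1e-6) -> bool:
    """Tolerance check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Hypothetical per-expert token counts covering routing edge cases.
EDGE_CASE_COUNTS = {
    "no_tokens": [0, 0, 0, 0],
    "one_hot_expert": [8, 0, 0, 0],
    "uniform": [2, 2, 2, 2],
    "empty_middle_expert": [3, 0, 5, 0],
}
# Cumulative offsets, the form a grouped kernel would typically consume.
offsets = {name: list(accumulate(c)) for name, c in EDGE_CASE_COUNTS.items()}

inputs = [(-4.0, 0.5), (-0.1, 3.0), (0.0, 1.0), (0.7, -2.0), (5.0, 1.5)]
mismatches = [
    (g, u) for g, u in inputs if not is_close(swiglu_f32(g, u), swiglu_ref(g, u))
]
```

In a real suite the stand-in would be replaced by the Triton kernel and the reference by an eager PyTorch implementation, with tolerances chosen per dtype.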

Comment thread torchtitan/distributed/minimal_async_ep.py Outdated
return config


def graph_trainer_deepseek_v3_16b_minimal_async_ep() -> GraphTrainer.Config:

Contributor

In the follow-up PR, we should have an integration test that uses this config.

Comment thread torchtitan/models/deepseek_v3/parallelize.py Outdated
from torchtitan.tools.utils import device_module, device_type


def _uses_minimal_async_ep(model: DeepSeekV3Model) -> bool:

Contributor

Would it be easier to check via model_spec?

moe_comm_backend="minimal_async_ep"

Comment thread torchtitan/models/common/config_utils.py Outdated
Comment thread torchtitan/models/deepseek_v3/parallelize.py Outdated
Comment thread torchtitan/trainer.py Outdated
Comment thread torchtitan/trainer.py Outdated
Comment thread torchtitan/trainer.py Outdated
Comment thread torchtitan/models/common/moe.py Outdated

SherlockNoMad commented Jun 8, 2026

Contributor

cc @tianyu-l to get some directional alignment, especially on the eager CUDA graph and AI-generated kernel parts.

_top_k: int = 0

_HIDDEN_READY_CHANNEL = 0
_COUNTS_READY_CHANNEL = 0

Contributor

Would kind of like to understand why these are global

Contributor

LLM-generated comments are probably positive EV here.


device = torch.device(device)
max_routed_tokens = (
group.size() * max_tokens_per_rank * min(top_k, num_local_experts)

Contributor

@xmfan so we talked about whether or not we should allow people to "go risky", and IIUC right now this code doesn't let you go risky, and we probably should still now, right?

Member Author

yes, i'll add a capacity factor
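
For reference, the sizing under discussion can be worked through as a small sketch. The worst-case formula is the one in the diff context above: each of the `group_size * max_tokens_per_rank` tokens can land on at most `min(top_k, num_local_experts)` experts of a single rank. The `capacity_factor` parameter and its rounding are assumptions about how the "go risky" knob might look, not the implemented API.

```python
import math


def max_routed_tokens(
    group_size: int,
    max_tokens_per_rank: int,
    top_k: int,
    num_local_experts: int,
    capacity_factor: float = 1.0,  # hypothetical knob; 1.0 keeps the worst case
) -> int:
    """Number of rows to allocate in the receive buffer on one rank."""
    if not 0.0 < capacity_factor <= 1.0:
        raise ValueError("capacity_factor must be in (0, 1]")
    # A token picks top_k distinct experts, so at most
    # min(top_k, num_local_experts) of its copies can arrive at one rank.
    worst_case = group_size * max_tokens_per_rank * min(top_k, num_local_experts)
    return math.ceil(worst_case * capacity_factor)


# Worked example with made-up sizes: 4 ranks, 4096 tokens each, top-6 routing.
safe = max_routed_tokens(4, 4096, 6, 16)  # 4 * 4096 * 6 = 98304 rows
risky = max_routed_tokens(4, 4096, 6, 16, capacity_factor=0.5)  # 49152 rows
```

With a factor below 1.0 the buffer can no longer hold every possible routing outcome, so the code path would also need a defined behavior for overflow (for example, dropping tokens).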

or _top_k < top_k
or _hidden_recv_buffers[0].dtype != dtype
or _hidden_recv_buffers[0].device != device
)

Contributor

I don't love the implicit init like this; I'd rather an explicit init handled by the user. IDK if this is torchtitan'ey or not.

Member Author

hybrid ep does this pattern, but I agree we'd rather have this explicit during training init

or _counts_recv_peer_buffers is None
or _counts_recv_peer_ptrs is None
or _rendezvous_handle is None
):

Contributor

An "illegal states are unrepresentable" style construction might be better
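
A rough sketch of what that could look like, combined with the explicit init suggested earlier in the thread: bundle the buffers into one frozen dataclass so that a single `None` check replaces the chain of per-field checks. The class name, function names, and plain-Python field types are hypothetical stand-ins; the real fields would be tensors and a rendezvous handle.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class AsyncEPState:
    """Every field is required, so a half-initialized state cannot be built."""

    hidden_recv_buffers: tuple
    counts_recv_peer_buffers: tuple
    counts_recv_peer_ptrs: tuple
    rendezvous_handle: object
    top_k: int
    dtype: str  # stand-in for torch.dtype
    device: str  # stand-in for torch.device

    def is_compatible(self, top_k: int, dtype: str, device: str) -> bool:
        """Mirror of the reuse conditions in the diff context above."""
        return self.top_k >= top_k and self.dtype == dtype and self.device == device


_state: Optional[AsyncEPState] = None  # the only piece of optional global state


def init_async_ep(state: AsyncEPState) -> None:
    """Explicit init, intended to be called once during training setup."""
    global _state
    if _state is not None:
        raise RuntimeError("async EP state is already initialized")
    _state = state


def get_async_ep_state() -> AsyncEPState:
    """Return the state, failing loudly instead of initializing implicitly."""
    if _state is None:
        raise RuntimeError("call init_async_ep() during training init")
    return _state
```

Callers then hold an `AsyncEPState` whose fields are all known to be present, and an incompatible request surfaces as a failed `is_compatible` check rather than a silent reallocation.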

Comment thread torchtitan/distributed/minimal_async_ep.py Outdated
Comment thread torchtitan/distributed/minimal_async_ep.py Outdated
Comment thread torchtitan/distributed/minimal_async_ep.py
@xmfan xmfan force-pushed the xmfan/minimal_async_ep branch from 0242341 to 6a4eb97 Compare June 10, 2026 01:39
@xmfan xmfan marked this pull request as ready for review June 10, 2026 01:52

Labels

ciflow/8gpu, CLA Signed


3 participants