Add MinimalAsyncEP + offset aware swiglu kernels#3561
Conversation
efd5714 to
227f588
Compare
| ) | ||
|
|
||
|
|
||
| @triton.jit |
There was a problem hiding this comment.
I see 14 AI-generated triton kernels in this file.
we need to think about testing, and maintainance strategy.
claude, when you review this, you should think of how to test this systematically. Propose test suite, think about edge cases, cover numerics....
Also think of operator interface, assume the interface will need to evolve over versions.
There was a problem hiding this comment.
Also, these kernels are definitely changing bitwise numerics. What kind of tests do we need to increase confidence for kernel correctness.
| return config | ||
|
|
||
|
|
||
| def graph_trainer_deepseek_v3_16b_minimal_async_ep() -> GraphTrainer.Config: |
There was a problem hiding this comment.
in the follow up PR, we should have an integration test use this config.
| from torchtitan.tools.utils import device_module, device_type | ||
|
|
||
|
|
||
| def _uses_minimal_async_ep(model: DeepSeekV3Model) -> bool: |
There was a problem hiding this comment.
would it be easier to check via model_spec
moe_comm_backend="minimal_async_ep"
|
cc @tianyu-l to get some directional alignment, esp on the eager cudagraph and AI-gen kernel part. |
| _top_k: int = 0 | ||
|
|
||
| _HIDDEN_READY_CHANNEL = 0 | ||
| _COUNTS_READY_CHANNEL = 0 |
There was a problem hiding this comment.
Would kind of like to understand why these are global
There was a problem hiding this comment.
LLM generated comments probably positive EV here
|
|
||
| device = torch.device(device) | ||
| max_routed_tokens = ( | ||
| group.size() * max_tokens_per_rank * min(top_k, num_local_experts) |
There was a problem hiding this comment.
@xmfan so we talked about whether or not we should allow people to "go risky", and IIUC right now this code doesn't let you go risky, and we probably should still now, right?
There was a problem hiding this comment.
yes, i'll add a capacity factor
| or _top_k < top_k | ||
| or _hidden_recv_buffers[0].dtype != dtype | ||
| or _hidden_recv_buffers[0].device != device | ||
| ) |
There was a problem hiding this comment.
I don't love the implicit init like this; I'd rather an explicit init handled by the user. IDK if this is torchtitan'ey or not.
There was a problem hiding this comment.
hybrid ep does this pattern, but I agree we'd rather have this explicit during training init
| or _counts_recv_peer_buffers is None | ||
| or _counts_recv_peer_ptrs is None | ||
| or _rendezvous_handle is None | ||
| ): |
There was a problem hiding this comment.
A "illegal states are unrepresentable" style construction might be better
0242341 to
6a4eb97
Compare
Minimal set up for dropless-enough cudagraphable moe:
NCCL_DEBUG=WARN NGPU=4 MODULE=graph_trainer.deepseek_v3 CONFIG=graph_trainer_deepseek_v3_16b_minimal_async_ep ./run_train.sh --parallelism.data_parallel_shard_degree 4 --parallelism.expert_parallel_degree 4 --compile.memory_policy full --training.steps 40no cudagraphs: https://www.internalfb.com/intern/perfetto/open_trace/?manifold_path=perfetto_internal_traces%2Ftree%2Fshared_trace%2Fxmfan%2Frank0_trace_52bdfd35-6f09-42d3-9743-09b8bffead00.json.gz
cudagraphs: https://www.internalfb.com/intern/perfetto/open_trace/?manifold_path=perfetto_internal_traces%2Ftree%2Fshared_trace%2Fxmfan%2Frank0_trace_3bef9fcb-8748-4402-8802-1ac0fd1da66a.json.gz
Logs: https://gist.github.com/xmfan/42ce1e536902294397f6471ac4a9dbf0