Skip to content

Integrate deep-ep nccl backend#4477

Open
irexyc wants to merge 26 commits into
InternLM:mainfrom
irexyc:moe-2
Open

Integrate deep-ep nccl backend#4477
irexyc wants to merge 26 commits into
InternLM:mainfrom
irexyc:moe-2

Conversation

@irexyc
Copy link
Copy Markdown
Collaborator

@irexyc irexyc commented Mar 27, 2026

Related link deepseek-ai/DeepEP#521

Copilot AI review requested due to automatic review settings March 27, 2026 12:51
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR integrates DeepEP-based Expert Parallelism (EP) over the NCCL backend into TurboMind, wiring EP initialization into runtime context creation and extending LLaMA MoE execution to support EP token routing/dispatch/combine.

Changes:

  • Add DeepEP/NCCL EP backend (NcclCommImpl::InitializeEp/Dispatch/Combine) and build it as a new deepep static library.
  • Extend TurboMind engine/model parameters for EP (ep_size, ep_rank, ll_max_tokens_per_rank) and initialize EP in TurboMind::Impl::CreateContext.
  • Update LLaMA unified decoder + MoE FFN to support EP routing and add a fused RMSNorm path that supports EP token partitioning (ReduceScatterV/AllGatherV).

Reviewed changes

Copilot reviewed 41 out of 42 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
src/turbomind/turbomind.cc Parse EP/LL params and initialize EP in device communicator during context setup
src/turbomind/models/llama/unified_decoder.{h,cc} Add EP-aware hidden-state layout + fused RMSNorm integration + partial-token FFN execution
src/turbomind/models/llama/moe_ffn_layer.{h,cc} Add EP routing/dispatch/combine implementation and EP-mode state
src/turbomind/models/llama/llama_params.h Add EP + LL threshold parameters to engine/moe config
src/turbomind/models/llama/LlamaDenseWeight.{h,cc} Shard MoE expert weights by ep_size/ep_rank
src/turbomind/models/llama/LlamaDecoderLayerWeight.{h,cc} Thread EP params into MoE weight construction; adjust MLP TP handling for EP
src/turbomind/models/llama/FusedRMSNormLayer.h New TP/EP fused RMSNorm abstraction with EP ReduceScatterV/AllGatherV
src/turbomind/kernels/gemm/moe_ep_utils.{h,cu} New kernels/utilities for EP gating and (LL/HT) combine helpers
src/turbomind/comm/device_comm.h Extend device-comm interface with ReduceScatterV/AllGatherV and EP APIs
src/turbomind/comm/nccl/{nccl_comm.h,nccl.cu,nccl_ep.cu} Refactor NCCL comm impl into header + add DeepEP EP ops
src/turbomind/comm/nccl/deep_ep/* Vendored DeepEP implementation and kernels
src/turbomind/comm/nccl/CMakeLists.txt Build/link deepep and include EP source in nccl_comm
lmdeploy/turbomind/turbomind.py Add EP parallel-config derivation in Python front-end
lmdeploy/turbomind/deploy/{config.py,converter.py,module.py} Plumb ep_size into deploy config and TP sizing for EP
lmdeploy/messages.py Add ep to TurbomindEngineConfig
lmdeploy/cli/serve.py Add CLI wiring to pass --ep into engine config
src/turbomind/models/llama/llama_utils.cu Add Compare<int64_t> instantiation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/turbomind/kernels/gemm/moe_ep_utils.cu Outdated
Comment thread src/turbomind/comm/nccl/nccl.cu Outdated
Comment on lines +89 to +107
if not complete_parallel_config(cfg) and cfg.ep > 1:
if cfg.communicator in ['cuda-ipc', 'native']:
assert cfg.nnodes == 1, 'TurboMind does not support multi-node with ep > 1'
total = cfg.dp * cfg.ep
if not cfg.device_num:
count = torch.cuda.device_count() * cfg.nnodes
if total < count:
count = total
cfg.device_num = count
assert total % cfg.device_num == 0
overlap = total // cfg.device_num
attn_dp_size = overlap
inner_tp_size = cfg.ep // overlap
cfg.outer_dp_size = cfg.dp // overlap
cfg.attn_dp_size = overlap // cfg.nnodes
cfg.attn_tp_size = inner_tp_size // cfg.cp
cfg.attn_cp_size = cfg.cp
cfg.mlp_dp_size = 1
cfg.mlp_tp_size = cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EP path can compute attn_dp_size = overlap // cfg.nnodes, which becomes 0 for common multi-node cases (e.g., overlap==1 and nnodes>1), violating later invariants and producing invalid parallel config. Since device_num already accounts for nnodes, avoid dividing overlap by nnodes here (or otherwise ensure attn_dp_size>=1 with a correct derivation).

Copilot uses AI. Check for mistakes.
Comment on lines 33 to +46

void SetWarpup(ForwardParam& p);

void ForwardNative(ForwardParam& p);

void ForwardFused(ForwardParam& p);

void RouteTP(ForwardParam& p, Tensor_<float>& logits);

void RouteEP(ForwardParam& p, Tensor_<float>& logits);

void CombineTP(ForwardParam& p);

void CombineEP(ForwardParam& p);
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in new private helper name SetWarpup (should be SetWarmup). Keeping the misspelling makes call sites harder to discover/search and looks inconsistent with the existing is_warm_up_ naming.

Copilot uses AI. Check for mistakes.
Comment on lines +52 to +58
struct EpCombineInput {
EpMode& mode;
core::Tensor& x;
std::vector<core::Tensor>& handle;
std::optional<core::Tensor> topk_weights;
std::optional<core::Tensor> topk_idx;
};
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EpCombineInput uses std::optional, but this header doesn't include <optional>, which will cause compilation errors depending on include order. Add #include <optional> (and keep headers self-contained).

Copilot uses AI. Check for mistakes.
Comment on lines +60 to +70
int comm_nranks_ = -1; // Number of ranks in NCCL communicator

ncclComm_t nccl_comm_;

ncclDevComm_t dev_ht_comm_{};
ncclDevComm_t dev_ll_comm_{};

std::unordered_map<void*, ncclWindow_t> wins_;
std::unordered_map<void*, size_t> buffers_;

// GIN signal management
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This header declares std::unordered_map members (wins_, buffers_) but doesn't include <unordered_map>, which will fail to compile in translation units that include this header first. Add the missing include (and any other required STL headers) to keep the header self-contained.

Copilot uses AI. Check for mistakes.
Comment on lines 589 to 604
@@ -596,7 +600,7 @@ MoeFfnWeight::MoeFfnWeight(int layer_id,
group_size,
act_type,
fuse_silu_act});
register_module("experts", *experts.back(), i);
register_module("experts", *experts.back(), i + expert_offset);
}
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

local_expert_num is computed via integer division (expert_num / ep_size) without validating divisibility. If expert_num isn't a multiple of ep_size, this will silently drop experts and mis-register / mis-load weights. Add a TM_CHECK_EQ(expert_num % ep_size, 0) (and ideally validate ep_rank < ep_size) before computing local_expert_num/expert_offset.

Copilot uses AI. Check for mistakes.
@lvhan028 lvhan028 added the enhancement New feature or request label Apr 2, 2026
irexyc and others added 14 commits April 7, 2026 14:08
- Resolve conflicts in nccl.cu / turbomind.cc / LlamaDecoderLayerWeight.cc /
  LlamaLinear.cu / CMakeLists.txt, adopting main's modern fmt-style logger
  while preserving moe-2's EP / DeepEP / ContextGuard additions.
- Migrate remaining printf-style TM_LOG_* calls in moe-2 added files
  (deep_ep.cpp, gin_backend.cu, nccl_ep.cu) to fmt-style ({} placeholders),
  rename TM_LOG_WARNING to TM_LOG_WARN, and switch utils/logger.h includes
  to core/logger.h.

Made-with: Cursor
irexyc and others added 4 commits May 18, 2026 04:55
Merge of origin/main (aed026f) into moe-2 (aa20784); merge-base e38927c.
Two upstream refactors required re-applying moe-2's DeepEP / expert-parallel
work onto a changed foundation rather than a textual merge:

- 01ddf16 (turbomind modeling infra): deleted LlamaDenseWeight /
  LlamaDecoderLayerWeight / LlamaWeight and the Python turbomind/deploy/*
  conversion stack; replaced by model_weight / decoder_layer_weight /
  moe_weight / ffn_weight / linear_weight / ... and a new Python loader.
  MoeParam/ModelParam removed; geometry now flows via core::*Config X-macros.
- a4025b9 (CUDA error handling): removed check_cuda_error /
  sync_check_cuda_error / FT_CHECK / CUDRVCHECK in favour of TM_CUDA_CHECK /
  TM_CUDRV_CHECK / TM_CHECK + manual scope tracing.

Port summary:
- EngineConfig + core::MoeConfig X-macros gain ep_size / ep_rank /
  ll_max_tokens_per_rank (auto-bound to Python). MoeWeight carries them and
  prepare() links only the local expert window
  (local_num_experts/local_expert_offset), exemplar = first local expert.
- moe_ffn_layer / unified_decoder reconciled onto origin ctors + MoeWeight /
  DecoderLayerWeight accessors; RouteEP/CombineEP/SetWarmup and
  FusedRMSNormLayer/HiddenStateLayout layered on. Fused-only path (origin has
  no MoeParam::kNaive equivalent).
- LlamaLinear: origin LinearWeight/out-param API + EP fp8-scales overload;
  dispatch driven by total mapping size to match merged moe_utils_v2
  (num_expert_tokens) semantics.
- nccl.cu: kept moe-2's out-of-line structure (needed by nccl_ep/nccl_comm.h),
  re-applied the a4025b9 macro/scope conversion.
- New moe-2 files (moe_ep_utils, nccl_ep, FusedRMSNormLayer.h) converted to the
  new error macros.
- turbomind.cc: dropped YAML parsing (EngineConfig-driven); EP InitializeEp
  relocated to ProcessWeights where ModelWeight geometry is known.
- Python: ec.ep_size / ec.ll_max_tokens_per_rank plumbed; make_moe_config
  gained EP fields (defaults keep the TP path identical).

Build verified: `ninja _turbomind` links the full extension cleanly.

Follow-up (runtime EP enablement, owner-validated): per-GPU ep_rank
ParallelGroup through the new model loader/builders and local-expert-range
expert construction across the rewritten model specs. Defaults make ep_size→1
(safe TP fallback); non-EP paths are unaffected.

Backup: branch backup/moe-2-premerge, tag premerge-moe-2.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

target_link_libraries(nccl_comm PRIVATE deepep)
else()
message(STATUS "Skip deepep build because NCCL ${NCCL_VERSION_STRING} < 2.29.7")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise FATAL error message?

int num_nodes;
int num_experts;
int experts_per_token;
int hidden;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convertional "hidden_size" or "hidden_dim" is more appreciated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants