
add preliminary support for EP (single-node) of turbomind backend #4332

Open
irexyc wants to merge 3 commits into InternLM:main from irexyc:moe-1

Conversation

@irexyc (Collaborator) commented Feb 6, 2026

Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification depends on a newer version of downstream projects, this PR should be tested with all supported versions of those downstream projects.
  4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

Copilot AI review requested due to automatic review settings (February 6, 2026, 07:54)
Copilot AI (Contributor) left a comment


Pull request overview

This pull request adds preliminary support for Expert Parallelism (EP) in single-node configurations to the TurboMind backend. EP is a parallelization strategy for Mixture-of-Experts (MoE) models where experts are distributed across multiple GPUs within a single node, using CUDA IPC for efficient communication.

Changes:

  • Adds EP configuration parameters (ep_size, ep_rank) throughout the engine and model stack
  • Implements EP-specific routing and communication patterns for MoE layers using AllToAll operations
  • Refactors RMSNorm layer to support both TP (Tensor Parallel) and EP modes
  • Adds new CUDA IPC communication primitives (ReduceScatterV, AllGatherV, AllToAll operations)
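
For orientation, here is a minimal sketch (illustrative only, not code from the PR; the helper names are hypothetical) of the even expert partition the EP path relies on: each of the ep_size ranks owns a contiguous slice of the experts, tokens routed to a non-local expert are sent to the owning rank by the AllToAll dispatch, and the results are returned by the combine step.

// Illustrative C++ sketch: map a global expert id to the EP rank that owns it
// and to its local index on that rank, assuming expert_num is split evenly
// across ep_size ranks (expert_num % ep_size == 0).
struct ExpertLocation {
    int ep_rank;       // rank holding the expert's weights
    int local_expert;  // expert index within that rank
};

inline ExpertLocation locate_expert(int expert_id, int expert_num, int ep_size)
{
    const int local_expert_num = expert_num / ep_size;
    return {expert_id / local_expert_num, expert_id % local_expert_num};
}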

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 8 comments.

Summary of changes per file:

  • src/turbomind/turbomind.cc: Adds EP parameter initialization and validation for the single-node/cuda-ipc constraint
  • src/turbomind/models/llama/unified_decoder.{h,cc}: Introduces FusedRMSNormLayer supporting both TP and EP modes for MoE
  • src/turbomind/models/llama/moe_ffn_layer.{h,cc}: Implements EP routing (RouteEP), forward, and combine operations using AllToAll
  • src/turbomind/models/llama/llama_params.h: Adds ep_size and ep_rank to EngineParam
  • src/turbomind/models/llama/context.h: Adds the d_ep_group communicator field
  • src/turbomind/models/llama/LlamaDenseWeight.{h,cc}: Updates MoeFfnWeight to partition experts across EP ranks
  • src/turbomind/models/llama/LlamaDecoderLayerWeight.{h,cc}: Propagates EP parameters and adjusts MLP TP size for EP mode
  • src/turbomind/comm/device_comm.h: Adds virtual methods for ReduceScatterV and AllGatherV
  • src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h: Declares EP-specific AllToAll communication operations
  • src/turbomind/comm/cuda_ipc/{reduce_scatterv,allgatherv,a2a_dispatch,a2a_combine}.cu: Implements CUDA IPC-based collective operations for EP
  • src/turbomind/kernels/gemm/moe_a2a_utils.{h,cu}: Adds MoE gate, scan, and combine kernels for the AllToAll pattern
  • src/turbomind/kernels/gemm/moe_utils_v2.{h,cu}: Exposes invokeMoeScan_v2 for reuse in the EP path
  • lmdeploy/turbomind/turbomind.py: Adds EP configuration logic with validation for the single-node constraint
  • lmdeploy/turbomind/deploy/module.py: Adjusts FFN tensor parallelism when EP is enabled
  • lmdeploy/turbomind/deploy/converter.py: Propagates ep_size to the model config
  • lmdeploy/turbomind/deploy/config.py: Adds the ep_size field to ModelConfig
  • lmdeploy/messages.py: Adds the ep parameter to TurbomindEngineConfig



// currently only support single nodes and cuda-ipc backend
c.d_ep_group = 0;
p.ep_rank = global_rank % p.ep_size;

Copilot AI Feb 6, 2026


The ep_rank is only initialized when comm_size_ > 1 (line 533). However, when comm_size_ == 1 (single GPU case), ep_rank is never set, which could lead to using an uninitialized value. Consider initializing ep_rank unconditionally or ensuring it defaults to 0 for the single GPU case.
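
A minimal sketch of the unconditional initialization this comment asks for (the helper is hypothetical and only mirrors the variables in the quoted hunk):

// Default to rank 0 so the single-GPU path (comm_size == 1) or the EP-disabled
// path never reads an uninitialized ep_rank.
inline int derive_ep_rank(int global_rank, int comm_size, int ep_size)
{
    if (comm_size <= 1 || ep_size <= 1) {
        return 0;
    }
    return global_rank % ep_size;
}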

cfg.device_num = count
assert total % cfg.device_num == 0
overlap = total // cfg.device_num
attn_dp_size = overlap

Copilot AI Feb 6, 2026


The variable 'attn_dp_size' is computed on line 101 but never used in this branch. It appears to be leftover from development. Consider removing this unused variable.

Suggested change
- attn_dp_size = overlap

}

// merge experts on the remote ranks
// TODO: support shared expert

Copilot AI Feb 6, 2026


The shared expert feature is marked as TODO/not supported in EP mode (line 358), but the shared expert scales are computed and stored earlier (line 164 in Forward). This could lead to wasted computation in EP mode when shared experts are configured. Consider skipping shared expert gate computation when ep_size > 1 until the feature is fully supported.
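
A sketch of the suggested guard as a small predicate (names are hypothetical; the actual layer would consult its own EP configuration and shared-gate weights):

// Shared experts are not yet supported in EP mode, so skip the shared-expert
// gate computation entirely when ep_size > 1.
inline bool shared_expert_enabled(int ep_size, bool has_shared_gate)
{
    return ep_size == 1 && has_shared_gate;
}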

Comment on lines +582 to +583
const int local_expert_num = expert_num / ep_size;
const int expert_offset = ep_rank * local_expert_num;

Copilot AI Feb 6, 2026


There is a potential issue if expert_num is not evenly divisible by ep_size: with local_expert_num = expert_num / ep_size, the integer division silently drops the remainder, so the last expert_num % ep_size experts would never be assigned to any rank. Consider adding a validation check that expert_num % ep_size == 0, or handling the remainder experts explicitly.
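
A self-contained sketch of the guarded partition (using assert for illustration; the PR would presumably use its TM_CHECK/FT_CHECK macros instead):

#include <cassert>

// Validate the even split before deriving the per-rank expert range.
inline void local_expert_range(int expert_num, int ep_size, int ep_rank,
                               int& local_expert_num, int& expert_offset)
{
    assert(expert_num % ep_size == 0 && "expert_num must be divisible by ep_size");
    local_expert_num = expert_num / ep_size;
    expert_offset    = ep_rank * local_expert_num;
}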

cfg.device_num = count
assert total % cfg.device_num == 0
overlap = total // cfg.device_num
attn_dp_size = overlap

Copilot AI Feb 6, 2026


The calculation of inner_tp_size = cfg.ep // overlap may result in silent truncation if cfg.ep is not evenly divisible by overlap. This could lead to unexpected behavior. Consider adding validation that cfg.ep % overlap == 0 to catch configuration errors early.

Suggested change
- attn_dp_size = overlap
+ attn_dp_size = overlap
+ assert cfg.ep % overlap == 0, (
+     f'Invalid parallel configuration: cfg.ep ({cfg.ep}) must be divisible by overlap ({overlap}).'
+ )

Comment on lines +53 to +56
~ZeroCopyItem()
{
check_cuda_error(cudaFreeHost(host_));
}

Copilot AI Feb 6, 2026


The ZeroCopyItem destructor calls check_cuda_error on cudaFreeHost, which could throw or abort during stack unwinding if destruction happens during exception handling. This violates the rule that destructors should not throw. Consider using a noexcept destructor that logs errors instead of calling check_cuda_error.
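
A sketch of a non-throwing destructor (class body trimmed to the relevant member; the fprintf logging is a stand-in for the project's own error logger):

#include <cstdio>
#include <cuda_runtime.h>

struct ZeroCopyItemSketch {
    void* host_{};  // pointer originally obtained from cudaMallocHost / cudaHostAlloc

    ~ZeroCopyItemSketch() noexcept
    {
        // Report the error instead of throwing or aborting during stack unwinding.
        const cudaError_t err = cudaFreeHost(host_);
        if (err != cudaSuccess) {
            std::fprintf(stderr, "cudaFreeHost failed: %s\n", cudaGetErrorString(err));
        }
    }
};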

d_comm_->AllToAllNotifyDispatch(
symm_meta_.data(), num_input.mapped(), num_flat.mapped(), token_idx_in_rank_.data(), tokens, 0, st);
sync_check_cuda_error();
while ((volatile int&)*num_input == -1 || (volatile int&)*num_flat == -1) {} // sync

Copilot AI Feb 6, 2026


This busy-wait loop spins a CPU core while waiting for the device to publish the counts. Consider a more efficient synchronization mechanism, such as waiting on a CUDA event or stream, or at least backing off between polls, instead of hot-spinning on the host.
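
If host-side polling has to stay (the flags are written by the device into mapped host memory, so a stream-level wait may not fit here), one low-cost mitigation is to yield between polls; a sketch keeping the flag semantics of the quoted hunk:

#include <thread>

// Poll the device-written flags, but yield the CPU between checks instead of
// hot-spinning on a core.
inline void wait_for_counts(volatile const int* num_input, volatile const int* num_flat)
{
    while (*num_input == -1 || *num_flat == -1) {
        std::this_thread::yield();
    }
}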

FT_CHECK(engine_param_.mlp_tp_size == comm_size_);

communicator_type_ = engine["communicator"].as<std::string>();
TM_CHECK(engine_param_.ep_size == 1 || (communicator_type_ != "nccl"));

Copilot AI Feb 6, 2026


The check logic is too permissive: engine_param_.ep_size == 1 || (communicator_type_ != "nccl") passes when ep_size is 1 (any communicator is fine) or when the communicator is anything other than NCCL. Since the PR states that EP currently supports only single-node setups with the cuda-ipc backend, the intended constraint is that ep_size > 1 requires the cuda-ipc communicator. The check should therefore be engine_param_.ep_size == 1 || communicator_type_ == "cuda-ipc"; as written, it would incorrectly allow ep_size > 1 with any non-NCCL communicator.

Suggested change
- TM_CHECK(engine_param_.ep_size == 1 || (communicator_type_ != "nccl"));
+ TM_CHECK(engine_param_.ep_size == 1 || communicator_type_ == "cuda-ipc");

