-
Notifications
You must be signed in to change notification settings - Fork 593
Port TRT-LLM communication kernels to flashinfer #2102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
WalkthroughAdds a throughput-optimized MOE all-to-all backend: new CUDA dispatch/combine/sanitize kernels and headers, C++ TVM/FFI bridge exposing workspace/init/dispatch/combine/sanitize/metainfo APIs, Python JIT/AOT integration and MoeAlltoAll manager, env helpers, meta-info utilities, tests, and CI script updates. Changes
Sequence Diagram(s)sequenceDiagram
participant App as Application
participant Py as Python MoeAlltoAll
participant JIT as JIT Module
participant Cpp as C++ FFI
participant CUDA as CUDA Kernels
participant Net as Network/P2P
App->>Py: moe_a2a_dispatch(...)
Py->>JIT: call registered dispatch op (workspace, metainfo)
JIT->>Cpp: FFI -> moe_a2a_dispatch
Cpp->>CUDA: launch prepare_dispatch kernel
Cpp->>CUDA: launch dispatch kernel
CUDA->>Net: write/send per-rank payloads
Net->>CUDA: peers receive payloads
App->>Py: moe_a2a_combine(...)
Py->>JIT: call combine op
JIT->>Cpp: FFI -> moe_a2a_combine
Cpp->>CUDA: launch prepare_combine (copy recv->workspace)
Cpp->>CUDA: launch combine kernel
CUDA->>Cpp: return outputs / set flags
App->>Py: moe_a2a_sanitize_expert_ids(...)
Py->>JIT: sanitize op
JIT->>Cpp: FFI -> sanitize entry
Cpp->>CUDA: launch sanitize kernel
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
710a388 to
bd82a2b
Compare
| #define check_timeout(s) false | ||
| #else | ||
| // 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang | ||
| #define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have manually added this, can I get someone to sanity check my logic here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me
f766bfe to
8cdf8d8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
🧹 Nitpick comments (18)
csrc/nv_internal/cpp/common/envUtils.cpp (1)
357-357: Consider cachinggetEnvEplbForceGdrcopylike other bool env helpers
getEnvEplbForceGdrcopycallsgetBoolEnv(and thusstd::getenv) on every invocation, while most other helpers in this file cache the value in astaticlocal. Functionally this is fine, but for consistency and to avoid repeated env lookups in hot paths you might want to align it:-bool getEnvEplbForceGdrcopy() { return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY"); } +bool getEnvEplbForceGdrcopy() { + static bool const forceGdrcopy = getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY"); + return forceGdrcopy; +}Not critical, but it would match the rest of the env-utils style.
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (2)
428-431: Consider documenting why the acquire fence is commented out.The
fence.acquire.sysat line 430 is commented out after the dispatch wait loop. While the combine kernel (line 735) does have an acquire fence, having it commented here without explanation could cause confusion for future maintainers. If this is intentional (relying on the combine kernel's fence), a brief comment explaining the design decision would help.} - // asm volatile("fence.acquire.sys;"); + // NOTE: Acquire fence intentionally omitted here; combine kernel provides + // the acquire semantics before reading peer data. #endif
596-609: Generic fallback is unreachable code.The generic fallback reduction loop (lines 599-608) can never be reached because the
SWITCH_TOP_Kmacro (lines 53-78) only allowsTOP_Kvalues of 1, 2, 4, or 8, and all these cases have explicit handling above. Consider removing this dead code or adding astatic_assertto document the constraint.} else if constexpr (TOP_K == 1) { // nothing to do - } else { - // Generic fallback: accumulate all into acc[0] - T* a0 = reinterpret_cast<T*>(&acc[0]); -#pragma unroll - for (int k = 1; k < TOP_K; ++k) { - T* ak = reinterpret_cast<T*>(&acc[k]); -#pragma unroll - for (int j = 0; j < elems_per_vec; ++j) { - a0[j] += ak[j]; - } - } + } else { + static_assert(TOP_K == 1 || TOP_K == 2 || TOP_K == 4 || TOP_K == 8, + "Only TOP_K values 1, 2, 4, 8 are supported"); }scripts/task_test_multi_node_comm_kernels.sh (1)
9-13: Disabling cache cleanup may cause stale import issues.The cache cleanup commands are commented out. If module refactoring occurs between test runs, stale
.pycfiles could cause import errors or unexpected behavior. Consider re-enabling these commands or documenting why they're disabled.tests/comm/test_trtllm_moe_alltoall.py (4)
1-2: Copyright year should be updated to 2025.The license header shows 2024 but this is a new file created in 2025.
-Copyright (c) 2024 by FlashInfer team. +Copyright (c) 2025 by FlashInfer team.
112-112: Potential issue with payload size calculation.
x[0].numel()gets the number of elements in the first row, but ifinput_tensorsis a list of 2D tensors, this calculates size per token correctly. However, the variable namepayload_size_per_tokenand the indexingx[0]could be clearer.- payload_size_per_token = sum([x[0].numel() * x.itemsize for x in input_tensors]) + payload_size_per_token = sum([x.shape[-1] * x.element_size() for x in input_tensors])
207-236: CUDA streams created but not explicitly cleaned up.The
cuda_streams_all_rankslist creates CUDA streams that are not explicitly destroyed. While Python's garbage collector will eventually clean them up, for test reliability consider using a context manager or explicit cleanup.
411-411: Minor typo in comment.Extra slash at end of comment.
- # For each expert selected for this token/ + # For each expert selected for this tokentests/comm/test_mnnvl_moe_alltoall.py (4)
37-46: Consider usingraisewithout exception name per Python best practices.The explicit
raise eis redundant; bareraisepreserves the traceback better.def safe_run(func, *args, **kwargs): comm = MPI.COMM_WORLD try: func(*args, **kwargs) except MPIExit as e: - raise e + raise except Exception as e: traceback.print_exc() comm.allgather(True) - raise e + raise
49-51: Test fixture should yield for proper cleanup semantics.Even though no cleanup is needed, the fixture pattern should include
yieldfor consistency.@pytest.fixture(autouse=True) def setup_test(): torch.manual_seed(0x1234) + yield
571-576: Blind exception catch may mask real initialization errors.Catching bare
Exceptionwhen checking MNNVL support could hide legitimate configuration issues. Consider catching specific exception types or at least logging the exception.try: MnnvlMemory.initialize() if not MnnvlMemory.supports_mnnvl(): pytest.skip("MNNVL not supported on this system") - except Exception: + except (RuntimeError, pynvml.NVMLError) as e: + # Log exception for debugging if needed pytest.skip("MNNVL not supported on this system")
709-712: Unused variableexpert_id_payload_indexas flagged by static analysis.The unpacked variable is never used. Either prefix with underscore or remove from unpacking.
- payloads, expert_id_payload_index = make_bfloat16_payloads( + payloads, _expert_id_payload_index = make_bfloat16_payloads( local_num_tokens, hidden_size, top_k, rank, token_selected_experts )flashinfer/comm/trtllm_moe_alltoall.py (5)
8-8: TODO comment should be addressed or tracked.The
# TODO Reviewcomment at the top suggests this module needs review. Consider removing after review or converting to a tracked issue.Would you like me to open an issue to track any remaining review items?
351-351: Mutable class attribute should useClassVarannotation.Per static analysis and Python best practices, mutable class attributes should be annotated with
typing.ClassVar.+from typing import ClassVar + class MoeAlltoAll: ... # Single shared workspace across the process - _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} + _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}
456-463: Assertions for validation could use proper exceptions in production.Using
assertfor validation is acceptable for debug builds but these checks may be skipped in optimized Python (python -O). Consider using explicitif/raisefor critical invariants.
610-610: Inefficient way to get element size.Creating an empty tensor just to get element size is wasteful. Use
torch.finfoortorch.iinfoor a lookup table instead.- element_size = torch.tensor([], dtype=dtype).element_size() + # More efficient: use dtype itemsize directly + element_size = torch.empty(0, dtype=dtype).element_size()Or better, consider caching element sizes or using:
element_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8
621-628:__all__is not sorted as noted by static analysis.Consider sorting for consistency, though this is a minor issue.
__all__ = [ "MoeAlltoAll", "moe_a2a_initialize", + "moe_a2a_combine", "moe_a2a_dispatch", - "moe_a2a_combine", + "moe_a2a_get_workspace_size_per_rank", "moe_a2a_sanitize_expert_ids", - "moe_a2a_get_workspace_size_per_rank", ]csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1)
78-120: Well-documented struct with clear field descriptions.The
MoeA2ADispatchParamsstruct has excellent inline documentation explaining each field's purpose and dimensions. The TODO on line 90-91 about renamingmax_tokens_per_ranktoruntime_max_tokens_per_rankshould be tracked.Would you like me to open an issue to track the TODO about renaming
max_tokens_per_rank?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
csrc/nv_internal/cpp/common/envUtils.cpp(2 hunks)csrc/nv_internal/tensorrt_llm/common/envUtils.h(2 hunks)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu(1 hunks)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h(1 hunks)csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h(1 hunks)csrc/trtllm_moe_a2a.cu(1 hunks)docs/api/comm.rst(1 hunks)flashinfer/aot.py(1 hunks)flashinfer/comm/__init__.py(1 hunks)flashinfer/comm/trtllm_moe_alltoall.py(1 hunks)flashinfer/jit/__init__.py(1 hunks)flashinfer/jit/comm.py(1 hunks)scripts/task_test_multi_node_comm_kernels.sh(1 hunks)tests/comm/test_mnnvl_memory.py(1 hunks)tests/comm/test_mnnvl_moe_alltoall.py(1 hunks)tests/comm/test_trtllm_moe_alltoall.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
flashinfer/jit/__init__.py (1)
flashinfer/jit/comm.py (1)
gen_mnnvl_a2a_module(83-109)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
csrc/nv_internal/cpp/common/envUtils.cpp (8)
getEnvKVCacheTimeOutputPath(275-278)getEnvKVCacheTimeOutputPath(275-275)getEnvMoeA2AOneBlockPerToken(326-333)getEnvMoeA2AOneBlockPerToken(326-326)getEnvMoeA2ADispatchBlockSize(347-350)getEnvMoeA2ADispatchBlockSize(347-347)getEnvMoeA2ACombineBlockSize(352-355)getEnvMoeA2ACombineBlockSize(352-352)
tests/comm/test_mnnvl_memory.py (1)
flashinfer/comm/mapping.py (1)
local_rank(391-392)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
csrc/trtllm_moe_a2a.cu (2)
getMoeA2AMetaInfoIndexPairs(395-407)getMoeA2AMetaInfoIndexPairs(395-395)
csrc/trtllm_moe_a2a.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (2)
getEnvMoeA2AOneBlockPerToken(326-333)getEnvMoeA2AOneBlockPerToken(326-326)
flashinfer/aot.py (1)
flashinfer/jit/comm.py (1)
gen_mnnvl_a2a_module(83-109)
flashinfer/jit/comm.py (1)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (4)
getEnvMoeA2ADispatchBlockSize(347-350)getEnvMoeA2ADispatchBlockSize(347-347)getEnvMoeA2ACombineBlockSize(352-355)getEnvMoeA2ACombineBlockSize(352-352)
flashinfer/comm/trtllm_moe_alltoall.py (4)
flashinfer/comm/mnnvl.py (5)
MnnvlMemory(232-551)MnnvlConfig(224-229)as_torch_strided_tensor(264-273)initialize(276-285)set_comm_from_config(288-293)flashinfer/comm/mapping.py (2)
Mapping(21-475)moe_ep_rank(349-350)flashinfer/jit/comm.py (1)
gen_mnnvl_a2a_module(83-109)include/flashinfer/trtllm/fused_moe/runner.h (1)
num_experts(263-263)
tests/comm/test_mnnvl_moe_alltoall.py (2)
flashinfer/comm/trtllm_moe_alltoall.py (4)
MoeAlltoAll(336-618)dispatch(484-541)get_combine_payload_tensor_in_workspace(585-618)combine(543-583)flashinfer/comm/mnnvl.py (3)
MnnvlMemory(232-551)initialize(276-285)supports_mnnvl(545-551)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (3)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
tensorrt_llm(23-104)csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
mnnvl_throughput(25-58)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (8)
moe_a2a_dispatch_launch(445-506)moe_a2a_dispatch_launch(445-445)moe_a2a_prepare_dispatch_launch(436-439)moe_a2a_prepare_dispatch_launch(436-436)moe_a2a_combine_launch(792-842)moe_a2a_combine_launch(792-792)moe_a2a_sanitize_expert_ids_launch(864-872)moe_a2a_sanitize_expert_ids_launch(864-866)
🪛 Clang (14.0.6)
csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h
[error] 19-19: 'array' file not found
(clang-diagnostic-error)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
[error] 18-18: 'cuda_bf16.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py
351-351: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
443-443: Avoid specifying long messages outside the exception class
(TRY003)
445-445: Avoid specifying long messages outside the exception class
(TRY003)
606-608: Avoid specifying long messages outside the exception class
(TRY003)
621-628: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
tests/comm/test_mnnvl_moe_alltoall.py
34-34: Avoid specifying long messages outside the exception class
(TRY003)
42-42: Use raise without specifying exception name
Remove exception name
(TRY201)
46-46: Use raise without specifying exception name
Remove exception name
(TRY201)
575-575: Do not catch blind exception: Exception
(BLE001)
676-676: Do not catch blind exception: Exception
(BLE001)
710-710: Unpacked variable expert_id_payload_index is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (38)
csrc/nv_internal/cpp/common/envUtils.cpp (2)
275-277: KV cache time output path now cached by reference – behavior looks consistentReturning
std::string const&to a function-local static is safe here and matches the pattern used by other helpers in this file (single evaluation of the env var, cheap repeated access). The only behavior change is that changes toTRTLLM_KVCACHE_TIME_OUTPUT_PATHafter the first call are no longer observed, which is usually fine for env-based config.If you don’t rely on mutating the env mid-process (including in tests), this change looks good to me.
326-355: Based on my analysis with the available information, I can now rewrite the review comment:Verify MOE A2A environment variable prefixes (
TLLM_vsTRTLLM_) against project documentation and call sites; fix misleading comment insanitizeBlockSizeto reflect "round up" instead of "nearest"The new helpers follow the file's consistent patterns (single-read statics, sane defaults), but need attention in two areas:
Env var prefixes for MoE A2A knobs
- These functions read
TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN,TLLM_MOE_A2A_DISPATCH_BLOCK_SIZE, andTLLM_MOE_A2A_COMBINE_BLOCK_SIZE- Review the project documentation, call sites, and surrounding code to confirm whether the
TLLM_prefix matches intended usage; most other envs in this file appear to useTRTLLM_prefix and silently using the wrong names would be problematic.
sanitizeBlockSizecomment accuracy
- The implementation rounds up to a multiple of 32:
block = (block + 31) / 32 * 32;(e.g., 33 → 64, not 32)- The comment currently states "Round to nearest multiple of 32 (warp size)", which is misleading.
Fix the comment:
- // Round to nearest multiple of 32 (warp size)
- // Round up to the next multiple of 32 (warp size)
The extra `if (block == 0) block = 256;` check after clamping is redundant but harmless.csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
95-102: LGTM! New MoE A2A environment variable accessors are properly declared.The three new accessor functions (
getEnvMoeA2AOneBlockPerToken,getEnvMoeA2ADispatchBlockSize,getEnvMoeA2ACombineBlockSize) are well-documented with default behaviors and align with their implementations inenvUtils.cpp.csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (3)
19-22: Static analysis false positive - standard headers are valid.The Clang error about
<array>not being found is a false positive. This is a standard C++11 header that should be available in any modern C++ environment. The includes are correct.
28-43: LGTM! Well-structured metadata index enum.The
MoeA2AMetaInfoIndexenum provides clear, sequential indexing for metadata fields withNUM_METAINFO_FIELDS = 9correctly representing the count of actual data fields (0-8). TheMoeA2ADataOffsetstype alias correctly uses this count for the array size.
45-58: LGTM! Useful name-to-index mapping function.The inline
getMoeA2AMetaInfoIndexPairs()function provides a clean way to expose metadata field names and their corresponding indices, which is consumed by the TVM FFI interface incsrc/trtllm_moe_a2a.cu.csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (4)
114-116: Timeout calculation looks reasonable.The timeout of
300ll * 2000ll * 1000ll * 1000llcycles (~600 billion) translates to approximately 300 seconds at 2 GHz, which provides a generous upper bound to prevent infinite hangs while allowing ample time for legitimate synchronization delays. The calculation avoids overflow by usinglong longliterals.
272-276: LGTM! Single-threaded flag increment is safe.The
flag_valincrement occurs only whenidx == 0, ensuring single-threaded access. Since this kernel runs sequentially in the stream before the dispatch kernel, there's no race condition.
844-872: LGTM! Sanitize kernel implementation is correct.The kernel correctly identifies invalid tokens (where
token_idx >= recv_counters[source_rank]) and sets their expert IDs toinvalid_id. Each thread operates on disjoint memory locations, avoiding any race conditions.
315-344: Based on my investigation, I cannot access the repository or find public documentation for thekMaxRanksconstant definition. The repository clone failed, and web searches returned no results for this internal NVIDIA TensorRT-LLM code.However, the core concern raised in the review comment remains valid and cannot be conclusively verified without access to:
- The header file containing
kMaxRanksdefinition- The actual value of
kMaxRanks- Runtime validation constraints on
ep_size(ensemble parallel size)The potential undefined behavior is legitimate: if
target_rankcan be 64 or greater, the bit shift1ULL << target_rankon auint64_twould indeed cause undefined behavior in C++.
Verify
kMaxRanksdoes not exceed 64 to avoid undefined behavior.The
already_copiedbitmask usesuint64_twith bit operations1ULL << target_rank. Iftarget_rankcan be 64 or greater, this causes undefined behavior (shifting by >= width of type). The code validatesparams.ep_size <= kMaxRanksat line 448, so ensurekMaxRanksis defined as ≤ 64 in the header.flashinfer/aot.py (1)
515-522: LGTM! MNNVL A2A module integration follows existing patterns.The new
gen_mnnvl_a2a_moduleis correctly imported within theadd_commblock and added under thehas_sm100condition, consistent with the existinggen_trtllm_comm_moduleandgen_trtllm_mnnvl_comm_moduleplacement.tests/comm/test_mnnvl_memory.py (1)
125-125: LGTM! Correct device selection for multi-node scenarios.Using
self.local_rankinstead ofself.rankis the correct fix for multi-node setups where the global rank may exceed the local GPU count. This aligns with the setup fixture (line 51) and theMapping.local_rankproperty shown in the relevant snippet.scripts/task_test_multi_node_comm_kernels.sh (1)
17-19: LGTM - new MoE A2A test coverage added.The addition of
test_mnnvl_moe_alltoall.pyaligns with the new MoE A2A functionality introduced in this PR.flashinfer/jit/__init__.py (1)
80-80: LGTM!The re-export follows the established pattern for module generators in this file.
flashinfer/comm/__init__.py (1)
42-52: LGTM!The new MoE A2A exports follow the established import patterns and properly expose the public API surface for the throughput backend.
csrc/trtllm_moe_a2a.cu (6)
53-99: LGTM - offset calculations and workspace sizing are well-structured.The alignment logic using cache-line boundaries (128 bytes) is appropriate for GPU memory access patterns. The offset calculations correctly account for metadata fields and payload regions.
101-123: LGTM - initialization logic is correct.Input validation is comprehensive. The
cudaStreamSynchronizeat line 119 ensures the memset completes before returning, which is necessary for correctness.
132-262: LGTM - dispatch implementation is solid.The function has comprehensive input validation, proper payload descriptor setup, workspace size verification, and appropriate CUDA error checking after kernel launch.
264-277: LGTM!The dtype conversion correctly handles the common floating-point types used in MoE computations. The unreachable return after the throw silences compiler warnings.
279-361: LGTM - combine implementation is well-validated.The function properly validates all inputs including the optional
payloadInWorkspaceoptimization for zero-copy scenarios. Error handling is comprehensive.
393-416: LGTM!The metainfo accessor correctly wraps the C++ pairs for Python accessibility, and all FFI exports are properly registered.
flashinfer/jit/comm.py (1)
83-109: I'll verify whether architecture-specific NVCC flags are required for thegen_mnnvl_a2a_modulefunction. Let me search the codebase for relevant information.
<function_calls>
#!/bin/bashFind the moeAlltoAllKernels.cu file
find . -name "moeAlltoAllKernels.cu" -type f 2>/dev/null
</function_calls>
<function_calls>
#!/bin/bashCheck if the file exists and search for architecture-specific patterns
if [ -f "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" ]; then
echo "=== Checking for SM guards and architecture macros ==="
rg -n 'sm|SM[0-9]|CUDA_ARCH|CUDA_SM' "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" || echo "No explicit SM guards found"echo ""
echo "=== Checking for tensor core/MMA operations ==="
rg -n 'mma|wmma|tensorop|__mma|__hmma' "csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu" || echo "No tensor core intrinsics found"
else
echo "File not found at expected path"
fi
</function_calls><function_calls>
#!/bin/bashLook at the comm.py file to understand nvcc_flags definition and usage
echo "=== Current comm.py around nvcc_flags definition ==="
rg -n 'nvcc_flags' flashinfer/jit/comm.py -B3 -A3echo ""
echo "=== All functions in comm.py to see pattern ==="
rg -n '^def gen_' flashinfer/jit/comm.py
</function_calls><function_calls>
#!/bin/bashLet's see the full context of comm.py to understand nvcc_flags and module patterns
wc -l flashinfer/jit/comm.py
</function_calls><function_calls>
#!/bin/bashRead the entire comm.py file to understand structure and nvcc_flags
cat -n flashinfer/jit/comm.py
</function_calls><function_calls>
#!/bin/bashCheck what nvcc_flags is and where it comes from
rg -n 'nvcc_flags\s*=' flashinfer/jit/comm.py -B2 -A2
echo ""
echo "=== Check for imports in comm.py ==="
head -30 flashinfer/jit/comm.py
</function_calls><function_calls>
#!/bin/bashCheck the jit module structure
ls -la flashinfer/jit/
</function_calls><function_calls>
#!/bin/bashCheck jit_env for nvcc_flags
rg -n 'nvcc_flags' flashinfer/jit/ -A2 -B2
</function_calls>tests/comm/test_trtllm_moe_alltoall.py (6)
25-29: Session-scoped fixture withtorch.manual_seedmay not reset seed between tests.The fixture sets the seed once at session start but individual tests may want reproducible seeds. Consider if this is intentional or if test-level seeding is needed.
62-69: Good defensive check for SM resources.This helper appropriately skips tests when insufficient SMs are available for parallel kernel execution. The check prevents hangs on systems with limited GPU resources.
72-84: LGTM!The
make_payloadhelper correctly distinguishes between integer and floating-point types for random tensor generation.
138-141: Sorting approach for validation is reasonable but fragile.Sorting both input and output tensors to compare them works for this test case but relies on unique values. If there are duplicate values, the sort order could differ. Consider documenting this assumption or using a more robust comparison.
388-429: Reference implementation for fake_moe looks correct.The fake_moe function properly handles expert parallelism filtering and accumulation. The tree reduction comment on line 423 correctly explains why results are summed after collection.
530-536: Relatively loose tolerance for numerical comparison.Using
atol=1.5e-2andrtol=1.5e-2is quite loose for bf16/fp16. This may mask precision issues. Verify this tolerance is intentional given the accumulation order differences mentioned elsewhere.tests/comm/test_mnnvl_moe_alltoall.py (3)
293-293: Direct modification of class variable_WORKSPACEis concerning.Setting
MoeAlltoAll._WORKSPACE = Nonedirectly before instantiation suggests test isolation concerns. This should be documented or handled via a proper reset method.Consider whether
_reset_workspace()method fromMoeAlltoAllshould be used instead, or if this pattern is intentional for test setup.
800-813: Good documentation of tolerance rationale.The comment on line 809 clearly explains why a 99% match threshold is used instead of exact comparison due to bf16 accumulation order differences. This is helpful for future maintainers.
836-838: Helpful run instructions in docstring.The comment showing how to run with
mpirunis useful for developers unfamiliar with MPI testing.flashinfer/comm/trtllm_moe_alltoall.py (3)
353-383: Workspace caching strategy looks correct.The caching by
(workspace_size_per_rank, ep_rank, ep_size, max_num_tokens)tuple allows reusing workspaces across instances with compatible configurations. This addresses the past review comment about supporting different shaped communicators.
470-482:_reset_workspacemethod deletes from class cache without thread safety.If multiple threads could access this class simultaneously, the
deloperation on_WORKSPACE_CACHEcould cause issues. Document that this method is not thread-safe.def _reset_workspace(self): - """Reset the workspace to free up its state. This is mainly used for testing. Use this with caution. This object is no longer usable after this.""" + """Reset the workspace to free up its state. + + Warning: This method is not thread-safe and is mainly used for testing. + This object is no longer usable after calling this method. + """
505-508: Good use of state machine pattern for dispatch/combine sequencing.The phase checking prevents calling dispatch twice without combine and ensures proper operation ordering. This is a clean design.
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (4)
17-19: Static analysis reports missingcuda_bf16.h- this is a false positive.The
cuda_bf16.handcuda_fp16.hheaders are provided by the CUDA toolkit and will be available during compilation with nvcc. Static analysis tools without CUDA environment cannot find these headers.
23-27: Configuration constants are well-documented and reasonable.The limits (256 experts, 8 top-k, 8 payloads, 64 ranks) provide good flexibility while keeping fixed-size arrays manageable. Consider whether these should be configurable at runtime if larger deployments are anticipated.
173-179: Function declarations are clean and match the implementation.The kernel launch function declarations align with the implementations shown in the relevant code snippets from
moeAlltoAllKernels.cu.
148-148: Unable to verify include configuration due to repository access failure.The repository clone failed, preventing me from examining the file's include structure, verifying whether
nvinfer1::DataTypeis actually used, or confirming if the necessary headers are already present. Manual verification is required to confirm:
- Whether
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.hcurrently includes NvInfer headers- Whether
nvinfer1::DataTypeis actually declared in the file or included transitively- Whether the code compiles successfully without the suggested include
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_moe_alltoall.py (3)
379-379: Annotate mutable class attribute withClassVar.Per Python best practices, mutable class attributes should be annotated with
ClassVarto make clear they are shared across instances.+from typing import ClassVar + class MoeAlltoAll: ... - _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} + _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}
638-638: Consider usingtorch.finfoortorch.iinfofor element size.Creating an empty tensor just to get element size has minor overhead. Consider using dtype introspection directly.
- element_size = torch.tensor([], dtype=dtype).element_size() + element_size = torch._utils._element_size(dtype)Alternatively, keep the current approach if you prefer avoiding private APIs.
649-656: Consider addingmoe_a2a_wrap_payload_tensor_in_workspaceto__all__.This function is used in tests and appears to be part of the public API. Also consider sorting
__all__for consistency.__all__ = [ "MoeAlltoAll", + "moe_a2a_combine", "moe_a2a_initialize", "moe_a2a_dispatch", - "moe_a2a_combine", + "moe_a2a_get_workspace_size_per_rank", "moe_a2a_sanitize_expert_ids", - "moe_a2a_get_workspace_size_per_rank", + "moe_a2a_wrap_payload_tensor_in_workspace", ]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
csrc/trtllm_moe_alltoall.cu(1 hunks)flashinfer/aot.py(1 hunks)flashinfer/comm/trtllm_moe_alltoall.py(1 hunks)flashinfer/jit/__init__.py(1 hunks)flashinfer/jit/comm.py(1 hunks)tests/comm/test_trtllm_moe_alltoall.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/aot.py
- flashinfer/jit/comm.py
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
Tensor(282-284)get_current_stream(266-270)encode_dlpack_dtype(29-31)flashinfer/comm/trtllm_moe_alltoall.py (6)
moe_a2a_get_workspace_size_per_rank(175-198)moe_a2a_get_workspace_size_per_rank(350-361)moe_a2a_initialize(41-47)moe_a2a_initialize(210-218)moe_a2a_dispatch(53-93)moe_a2a_dispatch(251-309)
tests/comm/test_trtllm_moe_alltoall.py (7)
flashinfer/comm/mapping.py (1)
Mapping(21-475)tests/test_helpers/test_helpers.py (1)
get_device_properties(10-11)include/flashinfer/trtllm/fused_moe/runner.h (1)
num_experts(263-263)flashinfer/comm/trtllm_moe_alltoall.py (13)
moe_a2a_get_workspace_size_per_rank(175-198)moe_a2a_get_workspace_size_per_rank(350-361)MoeAlltoAll(364-646)dispatch(512-569)get_combine_payload_tensor_in_workspace(613-646)combine(571-611)moe_a2a_initialize(41-47)moe_a2a_initialize(210-218)moe_a2a_dispatch(53-93)moe_a2a_dispatch(251-309)moe_a2a_sanitize_expert_ids(146-155)moe_a2a_sanitize_expert_ids(338-347)moe_a2a_wrap_payload_tensor_in_workspace(221-248)flashinfer/fused_moe/utils.py (1)
_(157-163)csrc/xqa/mha.cu (1)
any(157-157)tests/comm/test_mnnvl_moe_alltoall.py (1)
fake_moe(121-181)
🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py
379-379: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
471-471: Avoid specifying long messages outside the exception class
(TRY003)
473-473: Avoid specifying long messages outside the exception class
(TRY003)
634-636: Avoid specifying long messages outside the exception class
(TRY003)
649-656: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🔇 Additional comments (16)
flashinfer/jit/__init__.py (1)
80-80: LGTM!The new import follows the established pattern for re-exporting JIT module generators from the comm submodule.
csrc/trtllm_moe_alltoall.cu (5)
53-88: LGTM!The offset calculation logic correctly aligns data structures to cache-line boundaries where needed, following a clear sequential layout pattern.
101-123: LGTM!The initialization correctly validates inputs, zeros the workspace region, and returns metadata offsets. The stream synchronization ensures the workspace is properly initialized before returning.
278-360: LGTM!The combine operation has thorough input validation, properly handles the workspace-backed payload case, and includes appropriate error checking after kernel launch.
362-394: LGTM!The sanitization operation correctly validates inputs and launches the kernel with proper error checking.
396-419: LGTM!The metainfo index pairs helper provides a clean mechanism to expose C++ constants to Python, and all required functions are properly exported.
tests/comm/test_trtllm_moe_alltoall.py (6)
74-86: LGTM!The payload generator correctly handles both integer and floating-point dtypes for test data generation.
93-161: LGTM!Comprehensive single-GPU test covering multiple payload dtypes, dispatch/combine workflow, and workspace-backed tensor operations.
164-240: LGTM!The helper correctly simulates multi-rank dispatch on a single GPU using separate CUDA streams, with proper synchronization.
302-344: LGTM!The multi-rank test correctly validates token routing across simulated ranks with proper verification of payload delivery.
390-431: LGTM!The reference MoE implementation provides a deterministic baseline for verifying combine correctness, with appropriate handling of expert-parallel scenarios.
434-551: LGTM!Comprehensive combine test covering multiple dtypes, workspace configurations, and ranks with appropriate numerical tolerances for reduced-precision arithmetic.
flashinfer/comm/trtllm_moe_alltoall.py (4)
32-207: LGTM!The JIT module getter follows the established pattern with proper caching and custom op registration.
221-248: LGTM!The function correctly creates a workspace-backed tensor view with properly documented parameters.
470-473: LGTM!The validation logic is appropriate and the exception messages are concise.
498-510: LGTM!The reset method appropriately handles workspace cleanup for testing scenarios, with clear documentation about post-call state.
6e9bed5 to
a51b1ea
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (6)
csrc/trtllm_moe_alltoall.cu (1)
263-276: Consider extending dtype support for future flexibility.The
toNvDataTypefunction currently supports half, bfloat16, and float32. Consider documenting supported types or adding int8/fp8 support if those are common in MoE workloads.tests/comm/test_mnnvl_moe_alltoall.py (3)
711-712: Unused variable is intentional; consider underscore prefix per Ruff hint.The
expert_id_payload_indexis returned by the helper but not used in this test. Consider renaming to_expert_id_payload_indexto signal intentional discard.- payloads, expert_id_payload_index = make_bfloat16_payloads( + payloads, _expert_id_payload_index = make_bfloat16_payloads(
293-294: Setting class attribute_WORKSPACE = Nonemay conflict with class-level cache.Assigning
MoeAlltoAll._WORKSPACE = Noneresets a non-existent instance attribute. The class uses_WORKSPACE_CACHEfor caching. This assignment has no effect but is misleading.Consider removing this line or using
MoeAlltoAll._WORKSPACE_CACHE.clear()if the intent is to reset the cache:- MoeAlltoAll._WORKSPACE = None + MoeAlltoAll._WORKSPACE_CACHE.clear()
742-742: Same issue:_WORKSPACE = Noneassignment is ineffective.This line also sets a non-existent attribute. Consider removing or using
_WORKSPACE_CACHE.clear().- MoeAlltoAll._WORKSPACE = Noneflashinfer/comm/trtllm_moe_alltoall.py (2)
375-377: Annotate mutable class attribute withClassVarper Ruff hint.The
_WORKSPACE_CACHEis a mutable class-level attribute that should be annotated withClassVarto make the intent clear.+from typing import ClassVar + class MoeAlltoAll: ... - _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} + _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {}
496-508: Document that_reset_workspaceinvalidates the instance.The docstring mentions this but it's critical: after calling
_reset_workspace, the object is unusable. Consider adding a stronger warning or raising an exception on subsequent method calls.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
csrc/trtllm_moe_alltoall.cu(1 hunks)flashinfer/comm/trtllm_moe_alltoall.py(1 hunks)tests/comm/test_mnnvl_moe_alltoall.py(1 hunks)tests/comm/test_trtllm_moe_alltoall.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
Tensor(282-284)get_current_stream(266-270)encode_dlpack_dtype(29-31)flashinfer/comm/trtllm_moe_alltoall.py (11)
moe_a2a_get_workspace_size_per_rank(173-196)moe_a2a_get_workspace_size_per_rank(348-359)moe_a2a_initialize(39-45)moe_a2a_initialize(208-216)moe_a2a_dispatch(51-91)moe_a2a_dispatch(249-307)moe_a2a_combine(97-138)moe_a2a_combine(310-333)moe_a2a_sanitize_expert_ids(144-153)moe_a2a_sanitize_expert_ids(336-345)moe_a2a_get_metainfo_index_pairs(159-167)
flashinfer/comm/trtllm_moe_alltoall.py (3)
flashinfer/comm/mnnvl.py (5)
MnnvlMemory(232-551)MnnvlConfig(224-229)as_torch_strided_tensor(264-273)initialize(276-285)set_comm_from_config(288-293)flashinfer/comm/mapping.py (2)
Mapping(21-475)moe_ep_rank(349-350)flashinfer/jit/comm.py (1)
gen_mnnvl_moe_alltoall_module(83-109)
tests/comm/test_mnnvl_moe_alltoall.py (3)
flashinfer/comm/trtllm_moe_alltoall.py (4)
MoeAlltoAll(362-644)dispatch(510-567)get_combine_payload_tensor_in_workspace(611-644)combine(569-609)flashinfer/comm/mapping.py (1)
Mapping(21-475)flashinfer/comm/mnnvl.py (3)
MnnvlMemory(232-551)initialize(276-285)supports_mnnvl(545-551)
🪛 Ruff (0.14.6)
flashinfer/comm/trtllm_moe_alltoall.py
377-377: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
469-469: Avoid specifying long messages outside the exception class
(TRY003)
471-471: Avoid specifying long messages outside the exception class
(TRY003)
632-634: Avoid specifying long messages outside the exception class
(TRY003)
tests/comm/test_mnnvl_moe_alltoall.py
34-34: Avoid specifying long messages outside the exception class
(TRY003)
576-576: Do not catch blind exception: Exception
(BLE001)
677-677: Do not catch blind exception: Exception
(BLE001)
711-711: Unpacked variable expert_id_payload_index is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (44)
csrc/trtllm_moe_alltoall.cu (8)
1-51: LGTM! Well-structured header and utility functions.The license header, includes, and utility functions are appropriately organized. The
alignOffsetfunction correctly implements cache-line alignment using bitwise operations.
53-88: LGTM! Offset calculation is well-structured.The
calculateOffsetsfunction properly calculates memory offsets for various MoE A2A data structures with appropriate alignment for cache-line boundaries.
90-99: LGTM! Workspace size calculation.The workspace size calculation correctly accounts for metadata, payload, and combine regions with proper alignment.
101-123: LGTM! Initialize operation with proper validation.Good input validation for workspace dimensions, rank bounds, and proper error checking for CUDA operations. The synchronization before returning metainfo is appropriate.
125-261: LGTM! Dispatch operation is well-implemented.The dispatch function has comprehensive input validation, proper payload descriptor handling, and correct workspace pointer arithmetic. Error checking after kernel launch is appropriate.
278-360: LGTM! Combine operation with proper validation.The combine function correctly validates payload dimensions, workspace pointer alignment, and handles the
payloadInWorkspaceflag appropriately. Error checking after kernel launch is proper.
362-394: LGTM! Sanitize operation is correctly implemented.Proper input validation and error checking for the sanitize expert IDs kernel.
396-419: LGTM! Metainfo export and FFI registration.The metainfo index pairs function and TVM FFI exports are correctly implemented, providing clean Python interoperability.
tests/comm/test_trtllm_moe_alltoall.py (12)
25-29: LGTM! Docstring has been corrected.The fixture docstring now accurately describes that it sets the torch seed for deterministic tests.
32-60: Good test parameter coverage.The test parameters cover a good range of configurations (small, medium, large) for both single-GPU and multi-rank scenarios, with various dtypes and payload configurations.
63-72: Good resource-aware skip logic.The SM count check appropriately skips tests when hardware resources are insufficient, preventing false failures on less capable GPUs.
74-86: LGTM! Payload generation helper.The
make_payloadfunction correctly handles both integer and floating-point dtypes with appropriate random value generation.
89-162: Comprehensive single-GPU test with proper verification.The test covers dispatch and combine flows with multiple dtypes, validates output via sorting and exact comparison, and tests the workspace-backed combine path.
164-240: LGTM! Multi-rank dispatch helper is well-structured.The helper properly manages workspaces, initializes per-rank metadata, uses separate CUDA streams for parallel execution, and synchronizes appropriately.
243-259: LGTM! Sanitize helper function.Simple and correct delegation to the underlying sanitize function for each rank.
262-299: LGTM! Combine helper with parallel execution.The combine helper correctly uses separate streams per rank and synchronizes before returning results.
302-345: LGTM! Multi-rank test with proper verification.Good verification logic that filters non-zero tensors and compares sorted outputs against the reference filtered by expert assignment.
347-388: LGTM! Sanitize test with comprehensive verification.The test properly clones tensors before sanitization to enable before/after comparison and correctly verifies the sanitization logic.
390-431: LGTM! Reference MoE implementation for verification.The
fake_moefunction provides a clear reference implementation for verifying the distributed MoE behavior, with proper EP-rank filtering logic.
434-555: Good end-to-end combine test with tolerance handling.The test covers the full dispatch-process-combine cycle with both in-workspace and external payload paths. The tolerance values for bf16 are reasonable.
tests/comm/test_mnnvl_moe_alltoall.py (11)
27-46: MPI error handling utilities are well-designed.The
MPIExitexception,check_any_rank_failed, andsafe_runpattern provide robust MPI coordination for test failures across ranks, ensuring clean error propagation.
49-52: LGTM! Test fixture for deterministic seeding.
55-88: LGTM! Helper functions for expert routing and token generation.
compute_target_rank_idcorrectly implements contiguous expert partitioning, andgenerate_token_selected_expertsproperly generates random expert assignments.
91-119: LGTM! Expert weight creation with reproducible seeding.Using
ep_rank * 1000 + ias a seed ensures reproducibility across runs while differentiating experts per rank.
122-182: LGTM! Comprehensive fake MoE reference implementation.The function correctly handles both EP-rank and global modes with proper local expert ID conversion.
185-258: LGTM! Payload creation helpers.Both NV FP4 and BFloat16 payload creators are well-structured with appropriate rank-specific patterns for verification.
261-383: LGTM! Single-rank dispatch worker function.Comprehensive workspace setup, dispatch execution, and metadata extraction for MPI-based testing.
386-556: LGTM! Thorough dispatch verification.The
verify_dispatchfunction provides exhaustive validation of shapes, dtypes, counters, routing, and payload content. This is excellent for catching regressions.
572-577: Bare Exception catch is intentional for MNNVL availability check.The broad exception catch here is acceptable since it's used to detect MNNVL support availability across various failure modes (driver issues, missing hardware, etc.).
673-678: Bare Exception catch is acceptable for capability detection.Same as above - this is intentional for gracefully skipping tests on systems without MNNVL support.
657-818: LGTM! Full dispatch+combine cycle test.The test properly verifies the complete MoE A2A workflow with appropriate tolerance for bf16 accumulation order differences. The 99% match threshold is reasonable given the expected numerical variations.
flashinfer/comm/trtllm_moe_alltoall.py (13)
1-19: LGTM! Module header and imports are well-organized.Clean module docstring and appropriate imports for the MoE A2A functionality.
21-28: LGTM! State dataclass is appropriate.The
_A2AStatedataclass cleanly tracks the dispatch/combine lifecycle with appropriate fields.
30-205: LGTM! JIT module initialization with custom op registration.The lazy module loading with
@functools.cacheand custom op registration is well-structured. The returnedSimpleNamespaceprovides a clean API surface.
208-246: LGTM! Public wrapper functions are clean delegations.The top-level
moe_a2a_*functions provide clean interfaces to the JIT module, with appropriate docstrings where needed.
249-307: LGTM! Dispatch wrapper with tensor wrapping.The dispatch function correctly wraps the raw offsets into workspace-backed tensors for each payload.
310-359: LGTM! Combine, sanitize, and workspace size wrappers.Clean delegation to the underlying JIT module.
379-409: LGTM! Workspace caching with proper key management.The
get_workspaceclassmethod correctly caches workspaces by configuration tuple, preventing redundant allocations.
411-432: LGTM! Lazy metainfo constant initialization.The
_init_constantsmethod properly strips prefixes for a cleaner Python API.
434-494: LGTM! Constructor with proper validation and MNNVL configuration.Good input validation for
top_kandnum_experts, with optional MnnvlConfig support as discussed in past reviews.
510-567: LGTM! Dispatch method with proper state management.Good state assertions, lifecycle tracking, and optional sanitization flow.
569-609: LGTM! Combine method with state reset.Proper state validation and reset after combine completes, enabling the next dispatch/combine cycle.
611-645: LGTM! Workspace-backed tensor accessor.The
get_combine_payload_tensor_in_workspacemethod correctly computes slice bounds and validates state.
647-654: LGTM! Clean__all__export list.Explicitly defines the public API surface.
|
/bot run |
f93dd14 to
d400fd4
Compare
|
/bot run |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
♻️ Duplicate comments (5)
csrc/trtllm_moe_alltoall.cu (1)
356-388: Add missing validation for epRank and workspace dimensions.The sanitize operation lacks validation that could prevent out-of-bounds access:
- Missing check:
workspace.size(0) == epSize- Missing check:
epRank >= 0 && epRank < epSizeWithout these checks, line 376's pointer arithmetic
workspaceBase + epRank * workspace.stride(0)could access invalid memory if epRank is out of range or if the workspace dimensions don't match expectations.Apply this diff to add the missing validation:
void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, TensorView metainfo, int64_t epRank, int64_t invalidExpertId) { CHECK_INPUT(expertIds); CHECK_INPUT_TYPE(expertIds, dl_int32); TVM_FFI_ICHECK_EQ(expertIds.ndim(), 3); int64_t epSize = expertIds.size(0); int64_t runtimeMaxTokensPerRank = expertIds.size(1); int64_t topK = expertIds.size(2); CHECK_CPU(metainfo); CHECK_INPUT_TYPE(metainfo, dl_int64); TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); auto const* offsetsPtr = static_cast<int64_t const*>(metainfo.data_ptr()); fi_throughput::MoeA2ADataOffsets offsets{}; std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); CHECK_INPUT_TYPE(workspace, dl_uint8); TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); auto* workspaceBase = static_cast<uint8_t*>(workspace.data_ptr()); auto* rankWorkspacePtr = workspaceBase + epRank * workspace.stride(0);csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1)
173-179: Complete + reorder the sanitize API doc comment to match the signature.
Comment currently omitsep_size/max_tokens_per_rank/top_k/stream, and previous review already flagged ordering consistency—worth fixing here too.// Sanitize expert IDs for invalid tokens // expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32) // recv_counters: [ep_size] (int32), number of valid tokens per source // invalid_id: value to fill for invalid tokens' expert ids +// ep_size: Number of EP ranks +// max_tokens_per_rank: Maximum tokens per source rank in workspace +// top_k: Number of experts per token +// stream: CUDA stream void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id, int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (2)
122-134: Add defensive bounds checks forexpert_id → target_rankbefore indexing/bitshifts.
Ifexpert_idis negative or out of[0, ep_size * num_experts_per_rank),target_rankcan go out of range, causing OOB onsend_counters,recv_buffers,recv_counters, and1ULL << target_rankUB.- int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; - // Use contiguous partitioning to determine target rank - int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); + int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; + if (__builtin_expect(num_experts_per_rank <= 0, 0)) { + asm volatile("trap;"); + } + // Defensive: drop invalid expert ids rather than going OOB. + if (__builtin_expect(expert_id < 0 || expert_id >= num_experts_per_rank * ep_size, 0)) { + if (thread_idx == 0) { + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + smem_topk_target_ranks[k] = -1; + smem_topk_send_indices[k] = -1; + } + continue; + } + int target_rank = expert_id / num_experts_per_rank;Also applies to: 315-345
748-785: Use the same configurable block size inprepare_combineascombine(env var).
Right nowmoe_a2a_prepare_combine_launchhardcodes 256 whilemoe_a2a_combine_launchusesgetEnvMoeA2ACombineBlockSize(). This can desync the “warp vs block policy” mapping and grid sizing if the env var is changed.void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) { - constexpr int kBlockSize = 256; - constexpr int kWarpsPerBlock = kBlockSize / 32; // 8 warps per block + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize(); + int const kWarpsPerBlock = kBlockSize / 32;Also applies to: 799-841
tests/comm/test_mnnvl_moe_alltoall.py (1)
583-587: Fix allgather equality check to include all ranks (don’t skip rank 0).
On non-zero ranks, slicing[1:]skips validating rank 0’s value, which can mask divergence.gathered_all_num_tokens = comm.allgather(all_num_tokens) - assert all(i == all_num_tokens for i in gathered_all_num_tokens[1:]), ( + assert all(i == gathered_all_num_tokens[0] for i in gathered_all_num_tokens), ( "all_num_tokens should be the same" )
🧹 Nitpick comments (3)
scripts/task_test_multi_node_comm_kernels.sh (1)
9-15: Clarify or remove commented code.The Python bytecode cleanup (lines 9-12) and pip install (line 15) are commented out without explanation. If these operations are no longer needed, remove the commented lines. If they're temporarily disabled, add a comment explaining why.
-# echo "Cleaning Python bytecode cache..." -# find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true -# find . -type f -name '*.pyc' -delete 2>/dev/null || true -# echo "Cache cleaned." -# echo "" - -# pip install -e . -vflashinfer/comm/trtllm_moe_alltoall.py (1)
408-488: Ruff nits: annotate mutable class attrs asClassVar, and makeOptionalexplicit.
This is low-risk cleanup that will keep CI quieter.from dataclasses import dataclass from types import SimpleNamespace -from typing import Optional +from typing import Optional, ClassVar ... class MoeAlltoAll: ... - _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} + _WORKSPACE_CACHE: ClassVar[dict[tuple[int, int, int, int], dict]] = {} ... - workspace_size_per_rank: int = None, - hidden_size: int = None, + workspace_size_per_rank: int | None = None, + hidden_size: int | None = None, ... - _METAINFO_INDEX: Optional[dict] = None + _METAINFO_INDEX: ClassVar[dict | None] = NoneAlso applies to: 514-516
tests/comm/test_mnnvl_moe_alltoall.py (1)
569-576: Catching broadExceptionfor “MNNVL unsupported” is acceptable for tests, but consider narrowing.
If these blocks end up masking real infra/test bugs, it’ll be painful to debug.Also applies to: 669-675
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
csrc/nv_internal/cpp/common/envUtils.cpp(2 hunks)csrc/nv_internal/tensorrt_llm/common/envUtils.h(2 hunks)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu(1 hunks)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h(1 hunks)csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h(1 hunks)csrc/trtllm_moe_alltoall.cu(1 hunks)docs/api/comm.rst(1 hunks)flashinfer/aot.py(1 hunks)flashinfer/comm/__init__.py(1 hunks)flashinfer/comm/trtllm_moe_alltoall.py(1 hunks)flashinfer/jit/__init__.py(1 hunks)flashinfer/jit/comm.py(1 hunks)scripts/task_test_multi_node_comm_kernels.sh(1 hunks)scripts/task_test_single_node_comm_kernels.sh(1 hunks)tests/comm/test_mnnvl_memory.py(1 hunks)tests/comm/test_mnnvl_moe_alltoall.py(1 hunks)tests/comm/test_trtllm_moe_alltoall.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- scripts/task_test_single_node_comm_kernels.sh
- tests/comm/test_mnnvl_memory.py
- csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h
- flashinfer/aot.py
- flashinfer/jit/comm.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/trtllm_moe_alltoall.cucsrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
🧬 Code graph analysis (9)
csrc/nv_internal/cpp/common/envUtils.cpp (1)
include/flashinfer/trtllm/common.h (1)
getBoolEnv(195-198)
flashinfer/jit/__init__.py (1)
flashinfer/jit/comm.py (1)
gen_moe_alltoall_module(83-109)
csrc/trtllm_moe_alltoall.cu (2)
csrc/tvm_ffi_utils.h (3)
Tensor(287-289)get_current_stream(271-275)encode_dlpack_dtype(30-32)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (10)
moe_a2a_prepare_dispatch_launch(436-439)moe_a2a_prepare_dispatch_launch(436-436)moe_a2a_dispatch_launch(445-506)moe_a2a_dispatch_launch(445-445)moe_a2a_prepare_combine_launch(748-786)moe_a2a_prepare_combine_launch(748-748)moe_a2a_combine_launch(792-842)moe_a2a_combine_launch(792-792)moe_a2a_sanitize_expert_ids_launch(864-872)moe_a2a_sanitize_expert_ids_launch(864-866)
flashinfer/comm/__init__.py (1)
flashinfer/comm/trtllm_moe_alltoall.py (11)
MoeAlltoAll(393-732)moe_a2a_combine(99-140)moe_a2a_combine(319-342)moe_a2a_dispatch(53-93)moe_a2a_dispatch(257-315)moe_a2a_initialize(41-47)moe_a2a_initialize(202-210)moe_a2a_get_workspace_size_per_rank(359-390)moe_a2a_sanitize_expert_ids(146-155)moe_a2a_sanitize_expert_ids(346-355)moe_a2a_wrap_payload_tensor_in_workspace(214-253)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
csrc/nv_internal/cpp/common/envUtils.cpp (8)
getEnvKVCacheTimeOutputPath(275-278)getEnvKVCacheTimeOutputPath(275-275)getEnvMoeA2AOneBlockPerToken(326-333)getEnvMoeA2AOneBlockPerToken(326-326)getEnvMoeA2ADispatchBlockSize(347-350)getEnvMoeA2ADispatchBlockSize(347-347)getEnvMoeA2ACombineBlockSize(352-355)getEnvMoeA2ACombineBlockSize(352-352)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
csrc/nv_internal/cpp/common/envUtils.cpp (4)
getEnvMoeA2ADispatchBlockSize(347-350)getEnvMoeA2ADispatchBlockSize(347-347)getEnvMoeA2ACombineBlockSize(352-355)getEnvMoeA2ACombineBlockSize(352-352)
flashinfer/comm/trtllm_moe_alltoall.py (2)
flashinfer/comm/mnnvl.py (5)
MnnvlMemory(232-551)MnnvlConfig(224-229)as_torch_strided_tensor(264-273)initialize(276-285)set_comm_from_config(288-293)include/flashinfer/trtllm/fused_moe/runner.h (2)
num_experts(263-263)hidden_size(265-265)
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (3)
csrc/nv_internal/tensorrt_llm/common/envUtils.h (1)
tensorrt_llm(23-104)csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h (1)
moe_alltoall(25-58)csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (10)
moe_a2a_dispatch_launch(445-506)moe_a2a_dispatch_launch(445-445)moe_a2a_prepare_dispatch_launch(436-439)moe_a2a_prepare_dispatch_launch(436-436)moe_a2a_combine_launch(792-842)moe_a2a_combine_launch(792-792)moe_a2a_prepare_combine_launch(748-786)moe_a2a_prepare_combine_launch(748-748)moe_a2a_sanitize_expert_ids_launch(864-872)moe_a2a_sanitize_expert_ids_launch(864-866)
tests/comm/test_mnnvl_moe_alltoall.py (2)
flashinfer/comm/trtllm_moe_alltoall.py (3)
dispatch(596-653)get_combine_payload_tensor_in_workspace(699-732)combine(656-696)flashinfer/comm/mnnvl.py (1)
initialize(276-285)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_moe_alltoall.py
408-408: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
514-514: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
515-515: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
554-554: Avoid specifying long messages outside the exception class
(TRY003)
556-556: Avoid specifying long messages outside the exception class
(TRY003)
720-722: Avoid specifying long messages outside the exception class
(TRY003)
tests/comm/test_mnnvl_moe_alltoall.py
34-34: Avoid specifying long messages outside the exception class
(TRY003)
574-574: Do not catch blind exception: Exception
(BLE001)
673-673: Do not catch blind exception: Exception
(BLE001)
706-706: Unpacked variable expert_id_payload_index is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (33)
scripts/task_test_multi_node_comm_kernels.sh (1)
19-19: LGTM!The new test invocation follows the established pattern for other multi-node communication tests.
flashinfer/jit/__init__.py (1)
81-81: LGTM!The new import follows the established pattern for JIT module exports and properly surfaces the MOE A2A module generation functionality.
docs/api/comm.rst (1)
132-146: LGTM!The new documentation section properly covers the MNNVL A2A (Throughput Backend) public API. The autosummary list is complete and consistent with the symbols exported in
flashinfer/comm/__init__.py.csrc/nv_internal/tensorrt_llm/common/envUtils.h (2)
67-67: Note the API rename and return type change.The function
getEnvKVCacheTransferOutputPath()has been renamed togetEnvKVCacheTimeOutputPath()with a return type change fromstd::stringtostd::string const&. This is a breaking change for any external callers, though the implementation inenvUtils.cppproperly supports the new signature.
95-102: LGTM!The new MOE A2A environment utility functions are well-documented with clear default behaviors and follow the established patterns in this file.
csrc/nv_internal/cpp/common/envUtils.cpp (4)
275-278: LGTM!The renamed function implementation correctly returns a
const&to a static string, which is safe and efficient.
326-333: LGTM!The default-true behavior for
getEnvMoeA2AOneBlockPerToken()is clearly documented and the implementation correctly handles the unset case.
335-345: LGTM!The
sanitizeBlockSize()helper correctly clamps block sizes to CUDA-valid bounds (256-1024) and rounds to warp-size (32) multiples. The check on line 343 is defensive but redundant since rounding a value ≥1 to a multiple of 32 will always yield ≥32.
347-357: LGTM!The new MOE A2A block size getters properly apply sanitization to the environment variables, ensuring valid CUDA kernel launch parameters.
flashinfer/comm/__init__.py (1)
42-56: LGTM!The new MNNVL A2A exports are properly structured and consistent with the documentation. The symbols are correctly imported from
trtllm_moe_alltoallusing the explicit import-as pattern.csrc/trtllm_moe_alltoall.cu (8)
49-88: LGTM!The
calculateOffsets()function correctly computes the workspace layout with appropriate cacheline alignment for flag and data regions.
95-117: LGTM!The initialization operation properly validates inputs, zeroes the rank's workspace asynchronously, and synchronizes before returning the metainfo. Error checking for CUDA operations is thorough.
119-196: LGTM!The dispatch operation includes comprehensive input validation for token_selected_experts, payloads, workspace, and metadata. The alignment check on line 183-185 correctly catches potential misalignment issues early with a helpful error message.
198-255: LGTM!The parameter setup, kernel launch, and result construction are correct. The use of environment-based configuration (
getEnvMoeA2AOneBlockPerToken()) provides runtime flexibility, and error checking after the kernel launch is appropriate.
257-270: LGTM!The dtype conversion helper correctly maps DLPack types to TensorRT types for the supported MOE combine dtypes, with appropriate error handling for unsupported types.
272-354: LGTM!The combine operation includes thorough validation of all inputs, proper handling of both in-workspace and external payloads, and correct parameter setup for the CUDA kernel. The pointer validation when
payloadInWorkspace=True(lines 309-314) prevents subtle bugs.
390-404: LGTM!The metainfo index pairs helper correctly wraps the C++ implementation and returns the data in a Python-friendly format (tuple of arrays).
408-413: LGTM!All TVM FFI exports are properly declared and follow the established naming convention.
tests/comm/test_trtllm_moe_alltoall.py (11)
28-32: LGTM!The fixture docstring correctly describes its purpose after being updated per the past review feedback.
67-74: LGTM!The SM count check correctly skips tests when insufficient GPU resources are available for concurrent kernel execution.
77-89: LGTM!The payload generation helper correctly handles both integer and floating-point dtypes with appropriate random value generation.
100-168: LGTM!The single-GPU test provides good end-to-end coverage of the dispatch and combine workflow, including validation that shuffled data can be recovered after sorting.
171-247: LGTM!The multi-rank dispatch helper correctly orchestrates parallel dispatch operations across ranks using separate CUDA streams and proper synchronization.
250-306: LGTM!The sanitize and combine helpers provide clean abstractions for multi-rank testing scenarios.
310-351: LGTM!The multi-rank test correctly validates that dispatched payloads match the expected input data for each rank's assigned experts.
355-394: LGTM!The sanitize test properly validates that invalid expert IDs are masked while valid ones are preserved, matching the expected behavior per rank.
397-438: LGTM!The
fake_moereference implementation provides a clear, deterministic baseline for validating the combine operation results.
444-558: LGTM!The combine test thoroughly validates multi-rank MOE combine operations against the reference implementation, testing both workspace-backed and external payload modes across multiple dtypes.
565-614: LGTM!The workspace sizing test validates the consistency between the raw sizing API and the
MoeAlltoAllclass's sizing utilities, ensuring correct workspace allocation.csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h (1)
23-167: Struct shapes / constness look coherent for host-packing → device consumption.
The split ofDispatchKernelPointers(const src + mutable recv) vsCombineKernelPointers(mutable output + const recv) is a clean contract.flashinfer/comm/trtllm_moe_alltoall.py (2)
595-653: Nice: state machine + optional sanitize hook is easy to reason about.
Theidle → dispatched → idleflow and keepingcombine_payload_offsetin state makes the public API hard to misuse.
32-198: Themutates_argscontract is correctly implemented. PyTorch'storch.library.custom_opexpects argument names (strings), not indices, and flashinfer's type hintUnion[str, Iterable[str]]correctly reflects this. The usage patternmutates_args=("workspace",)in this file is canonical across the entire codebase (also seen inxqa.pywithmutates_args=("output", "workspace_buffer")and insampling.pywithmutates_args=()). This format will work correctly with FX alias analysis whenregister_custom_opis fully enabled.csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu (1)
111-116: Clarify timeout comment—"prevent a hang" is misleading for a ~35-day timeout.
The constant(300ll * 2000ll * 1000ll * 1000ll)cycles translates to millions of seconds (order of days) at typical GPU clock rates, making "will prevent a hang" inaccurate. The timeout effectively disables early termination rather than preventing hangs. Consider rephrasing to clarify the actual intent: whether this is meant as a safety measure that should never fire, or if the duration is a genuine concern for the all-to-all synchronization workload.
| template <typename ThreadingPolicy, int TOP_K> | ||
| __global__ void moeA2ADispatchKernel( | ||
| int32_t const* token_selected_experts, // [local_num_tokens, TOP_K] | ||
| const DispatchKernelPointers ptrs, // Struct containing all kernel pointers | ||
| int num_payloads, // Number of payloads | ||
| int max_tokens_per_rank, // Maximum tokens per rank | ||
| int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank) { | ||
| int thread_idx = ThreadingPolicy::offset(); | ||
| int local_token_idx = ThreadingPolicy::token_idx(); | ||
|
|
||
| if (local_token_idx >= local_num_tokens) { | ||
| return; | ||
| } | ||
|
|
||
| // Prepare per-policy shared-memory tiles for this token | ||
| extern __shared__ int smem[]; | ||
| int* smem_topk_target_ranks; | ||
| int* smem_topk_send_indices; | ||
| int warps_per_block = blockDim.x / warpSize; | ||
| if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value) { | ||
| int lane_id = threadIdx.x / warpSize; | ||
| smem_topk_target_ranks = smem + lane_id * TOP_K; | ||
| smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; | ||
| } else { | ||
| smem_topk_target_ranks = smem; | ||
| smem_topk_send_indices = smem + TOP_K; | ||
| } | ||
|
|
||
| uint64_t already_copied = 0; | ||
| for (int k = 0; k < TOP_K; k++) { | ||
| int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; | ||
| // Use contiguous partitioning to determine target rank | ||
| int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); | ||
|
|
||
| if (already_copied & (1ULL << target_rank)) { | ||
| if (thread_idx == 0) { | ||
| ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; | ||
| ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; | ||
| // Mirror to shared memory immediately | ||
| smem_topk_target_ranks[k] = -1; | ||
| smem_topk_send_indices[k] = -1; | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // Only one thread per warp should increment the counter | ||
| int dst_token_idx; | ||
| if (thread_idx == 0) { | ||
| dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); | ||
|
|
||
| ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; | ||
| ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; | ||
| // Mirror to shared memory immediately | ||
| smem_topk_target_ranks[k] = target_rank; | ||
| smem_topk_send_indices[k] = dst_token_idx; | ||
| } | ||
| already_copied |= 1ULL << target_rank; | ||
| } | ||
| // Sync before dispatching data | ||
| ThreadingPolicy::sync(); | ||
|
|
||
| // Read staged routing once into registers per thread | ||
| int topk_target_ranks[TOP_K]; | ||
| int topk_send_indices[TOP_K]; | ||
| #pragma unroll | ||
| for (int k = 0; k < TOP_K; ++k) { | ||
| topk_target_ranks[k] = smem_topk_target_ranks[k]; | ||
| topk_send_indices[k] = smem_topk_send_indices[k]; | ||
| } | ||
|
|
||
| // Perform a single source load and TOP_K fanout per payload | ||
| for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { | ||
| uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]); | ||
| int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; | ||
| uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; | ||
|
|
||
| vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, | ||
| max_tokens_per_rank, payload_idx, ptrs, | ||
| topk_target_ranks, topk_send_indices); | ||
| } | ||
|
|
||
| ThreadingPolicy::sync(); | ||
|
|
||
| bool is_first_warp = threadIdx.x / warpSize == 0; | ||
| if (is_first_warp) { | ||
| int lane_id = threadIdx.x % warpSize; | ||
|
|
||
| bool is_last_token = false; | ||
| if (lane_id == 0) { | ||
| int cnt = atomicAdd(ptrs.local_token_counter, 1); | ||
| is_last_token = cnt + 1 == local_num_tokens; | ||
| } | ||
| is_last_token = __shfl_sync(0xffffffff, is_last_token, 0); | ||
|
|
||
| if (is_last_token) { | ||
| // Store send_counters to recv_counters | ||
| #pragma unroll 1 // No unroll as one iter is typically enough | ||
| for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { | ||
| int send_count = ptrs.send_counters[target_rank]; | ||
| ptrs.recv_counters[target_rank][rank_id] = send_count; | ||
| } | ||
|
|
||
| #if !DISABLE_SYNC_FOR_PROFILING | ||
| uint32_t expected_value = *ptrs.flag_val; | ||
|
|
||
| asm volatile("fence.release.sys;"); | ||
| #pragma unroll 1 // No unroll as one iter is typically enough | ||
| for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { | ||
| uint32_t* flag_addr = &ptrs.completion_flags[target_rank][rank_id]; | ||
| asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); | ||
|
|
||
| #if ENABLE_DEBUG_PRINT | ||
| printf("dispatch: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, | ||
| expected_value, target_rank); | ||
| #endif | ||
| } | ||
|
|
||
| #pragma unroll 1 // No unroll | ||
| for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { | ||
| bool flag_set = false; | ||
| [[maybe_unused]] auto s = clock64(); | ||
| do { | ||
| uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; | ||
| uint32_t flag_value; | ||
| // Acquire load to ensure visibility of peer's release-store | ||
| asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); | ||
| #if ENABLE_DEBUG_PRINT | ||
| printf( | ||
| "dispatch: ---Rank %d received completion flag from rank %d, flag_value: %d, " | ||
| "expected_value: " | ||
| "%d, address: %p\n", | ||
| rank_id, peer_rank, flag_value, expected_value, flag_ptr); | ||
| #endif | ||
| flag_set = flag_value == expected_value; | ||
| } while (!flag_set && !check_timeout(s)); | ||
|
|
||
| if (__builtin_expect(!flag_set, 0)) { | ||
| printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", | ||
| rank_id, peer_rank); | ||
| asm volatile("trap;"); | ||
| return; | ||
| } | ||
| } | ||
| // asm volatile("fence.acquire.sys;"); | ||
| #endif | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CRITICAL: local_token_counter / “last token” detection is wrong for WarpPolicy (only first warp participates).
With WarpPolicy, each block processes multiple tokens (one per warp), but only the first warp increments ptrs.local_token_counter. That breaks “last token” detection and can cause missing/early flag signaling and deadlocks/corruption when one_block_per_token == false.
- bool is_first_warp = threadIdx.x / warpSize == 0;
- if (is_first_warp) {
- int lane_id = threadIdx.x % warpSize;
-
- bool is_last_token = false;
- if (lane_id == 0) {
- int cnt = atomicAdd(ptrs.local_token_counter, 1);
- is_last_token = cnt + 1 == local_num_tokens;
- }
- is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
-
- if (is_last_token) {
+ int lane_id = threadIdx.x % warpSize;
+ int warp_id = threadIdx.x / warpSize;
+
+ bool is_last_token = false;
+ if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value) {
+ // One token per warp => count once per warp.
+ if (lane_id == 0) {
+ int cnt = atomicAdd(ptrs.local_token_counter, 1);
+ is_last_token = (cnt + 1) == local_num_tokens;
+ }
+ is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
+ } else {
+ // One token per block => count once per block (use warp0 lanes).
+ if (warp_id == 0 && lane_id == 0) {
+ int cnt = atomicAdd(ptrs.local_token_counter, 1);
+ is_last_token = (cnt + 1) == local_num_tokens;
+ }
+ // Broadcast to warp0; only warp0 will run the finalization section below.
+ is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
+ }
+
+ if (is_last_token && (std::is_same<ThreadingPolicy, WarpPolicy>::value || warp_id == 0)) {
// Store send_counters to recv_counters
...
}
- }(ellipsis = keep your existing finalization body)
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
around lines 287-434, the last-token detection only increments
ptrs.local_token_counter from the first warp, which is incorrect for WarpPolicy
where each warp processes its own token(s). Fix by making the atomicAdd be
performed once per warp (not only once per block): replace the current "if
(is_first_warp) { ... if (lane_id==0) { int cnt = atomicAdd(...); ... }
is_last_token = __shfl_sync(...)" pattern with a ThreadingPolicy-aware approach
where for WarpPolicy the lane with threadIdx.x % warpSize == 0 performs the
atomicAdd and broadcasts the is_last_token to its warp (use __shfl_sync), and
for non-warp policies keep the existing single-thread-per-block atomicAdd and
broadcast; ensure the final is_last_token check and subsequent peer flag logic
use the warp-broadcasted value so every warp participates correctly in storing
send_counters and flag setting.
| int dst_token_idx; | ||
| if (thread_idx == 0) { | ||
| dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); | ||
|
|
||
| ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; | ||
| ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; | ||
| // Mirror to shared memory immediately | ||
| smem_topk_target_ranks[k] = target_rank; | ||
| smem_topk_send_indices[k] = dst_token_idx; | ||
| } | ||
| already_copied |= 1ULL << target_rank; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard against atomicAdd(send_counters[target_rank]) overflowing max_tokens_per_rank.
If callers underestimate runtime_max_tokens_per_rank (or routing is more skewed than expected), dst_token_idx can exceed the allocated per-rank slot count and you’ll write past recv_buffers[target_rank][payload].
if (thread_idx == 0) {
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
+ if (__builtin_expect(dst_token_idx >= max_tokens_per_rank, 0)) {
+ printf("dispatch: overflow max_tokens_per_rank=%d for target_rank=%d\n",
+ max_tokens_per_rank, target_rank);
+ asm volatile("trap;");
+ }Also applies to: 357-366
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
around lines 333-343 (and similarly 357-366), the atomicAdd on
ptrs.send_counters[target_rank] can produce dst_token_idx >= max_tokens_per_rank
causing out-of-bounds writes; change the logic to check the returned
dst_token_idx against ptrs.runtime_max_tokens_per_rank (or the per-rank max)
immediately after atomicAdd, and if it is >= max: revert the increment (or
atomically decrement), mark/record an overflow/error for that target (so the
caller can handle/skewed routing can be retried), skip writing into send
indices/target arrays and shared memory, and ensure any downstream code does not
use that dst_token_idx; alternatively, clamp within bounds only if you also
ensure no overwrite and signal overflow — implement the guard check and an
explicit overflow handling path in both places.
| template <int VEC_SIZE, int TOP_K, typename ThreadingPolicy, typename T> | ||
| __device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, int rank_id, | ||
| int max_tokens_per_rank, | ||
| CombineKernelPointers const& ptrs) { | ||
| constexpr int elems_per_vec = VEC_SIZE / sizeof(T); | ||
| using flashinfer::vec_t; | ||
|
|
||
| uint8_t* dst_bytes = reinterpret_cast<uint8_t*>(dst_typed_base); | ||
|
|
||
| int const stride = ThreadingPolicy::stride() * VEC_SIZE; | ||
| int const local_token_idx = ThreadingPolicy::token_idx(); | ||
|
|
||
| for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < size_per_token; | ||
| offset += stride) { | ||
| vec_t<uint8_t, VEC_SIZE> acc[TOP_K]; | ||
|
|
||
| // Unrolled K accumulation using compact top-k lists | ||
| #pragma unroll | ||
| for (int k = 0; k < TOP_K; ++k) { | ||
| int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k]; | ||
| int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k]; | ||
| if (dst_idx < 0) { | ||
| acc[k].fill(0); | ||
| continue; | ||
| } | ||
|
|
||
| uint8_t const* recv_buffer = static_cast<uint8_t const*>(ptrs.recv_buffers[target_rank][0]); | ||
| size_t base_source_rank = | ||
| static_cast<size_t>(rank_id) * static_cast<size_t>(max_tokens_per_rank) + | ||
| static_cast<size_t>(dst_idx); | ||
| size_t base_token = base_source_rank * static_cast<size_t>(size_per_token); | ||
|
|
||
| // Load directly into the per-k accumulator; reduce across k below | ||
| acc[k].load(recv_buffer + base_token + offset); | ||
| } | ||
|
|
||
| // Reduce acc[TOP_K] into acc[0] | ||
| if constexpr (TOP_K == 8) { | ||
| T* a0 = reinterpret_cast<T*>(&acc[0]); | ||
| T* a1 = reinterpret_cast<T*>(&acc[1]); | ||
| T* a2 = reinterpret_cast<T*>(&acc[2]); | ||
| T* a3 = reinterpret_cast<T*>(&acc[3]); | ||
| T* a4 = reinterpret_cast<T*>(&acc[4]); | ||
| T* a5 = reinterpret_cast<T*>(&acc[5]); | ||
| T* a6 = reinterpret_cast<T*>(&acc[6]); | ||
| T* a7 = reinterpret_cast<T*>(&acc[7]); | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a1[j]; | ||
| a2[j] += a3[j]; | ||
| a4[j] += a5[j]; | ||
| a6[j] += a7[j]; | ||
| } | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a2[j]; | ||
| a4[j] += a6[j]; | ||
| } | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a4[j]; | ||
| } | ||
| } else if constexpr (TOP_K == 4) { | ||
| T* a0 = reinterpret_cast<T*>(&acc[0]); | ||
| T* a1 = reinterpret_cast<T*>(&acc[1]); | ||
| T* a2 = reinterpret_cast<T*>(&acc[2]); | ||
| T* a3 = reinterpret_cast<T*>(&acc[3]); | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a1[j]; | ||
| a2[j] += a3[j]; | ||
| } | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a2[j]; | ||
| } | ||
| } else if constexpr (TOP_K == 2) { | ||
| T* a0 = reinterpret_cast<T*>(&acc[0]); | ||
| T* a1 = reinterpret_cast<T*>(&acc[1]); | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += a1[j]; | ||
| } | ||
| } else if constexpr (TOP_K == 1) { | ||
| // nothing to do | ||
| } else { | ||
| // Generic fallback: accumulate all into acc[0] | ||
| T* a0 = reinterpret_cast<T*>(&acc[0]); | ||
| #pragma unroll | ||
| for (int k = 1; k < TOP_K; ++k) { | ||
| T* ak = reinterpret_cast<T*>(&acc[k]); | ||
| #pragma unroll | ||
| for (int j = 0; j < elems_per_vec; ++j) { | ||
| a0[j] += ak[j]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| acc[0].store(dst_bytes + offset); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🌐 Web query:
NVIDIA CUDA bf16 fp16 fp32 accumulation pattern float reduction TensorRT-LLM
💡 Result:
Short summary — how NVIDIA handles BF16/FP16/FP32 accumulation and reductions (Tensor cores, CUDA, TensorRT / TensorRT-LLM):
-
Tensor cores: BF16 and FP16 inputs typically use a higher-precision accumulator for GEMM/WMMA/HMMA. BF16 matrix ops are implemented with FP32 accumulation on NVIDIA hardware; FP16 tensor ops may be accumulated in either FP16 or FP32 depending on ISA/arch and chosen mode (FP32 accumulation is common for numerical stability). [2][1]
-
CUDA / performance guidance: NVIDIA’s tuning docs and TensorRT best-practices recommend using FP32 accumulation for key BLAS/GEMM/BMM stages (especially BMM1/BMM2 inside MHA) when using FP16/BF16 inputs to avoid loss of accuracy in reductions. On SM architectures that support it, fused kernels still use FP32 accumulators for higher-precision sums. [2][3]
-
TensorRT / TensorRT-LLM: TensorRT-LLM supports FP32, FP16 and BF16 model precisions; its MHA/GEMM fusions and plugins often use FP32 accumulation even when weights/activations are FP16/BF16. The TensorRT-LLM docs and NVIDIA blog note that some plugins (e.g., gemm/attention) explicitly perform matmuls with FP32 accumulation for accuracy. TensorRT may choose FP32 accumulation for BMM stages in LLM inference. [1][4]
-
Practical rule of thumb for reductions (dot-products, matmul, softmax logits, attention sums):
- Use FP32 accumulator for long reductions / large K/N to avoid catastrophic cancellation and to match hardware fused-tactic defaults. [3][2]
- BF16 gives FP32-like dynamic range but lower mantissa — still benefit from FP32 accumulation for sums. [2][6]
- If latency/memory is critical, FP16 accumulation is faster on some paths but risks degraded numeric stability; benchmark with your model and enable FP32 accumulation where accuracy matters (Torch-TensorRT offers use_fp32_acc option). [5]
References
- TensorRT-LLM — Numerical Precision. [1]
- NVIDIA CUDA / Ampere Tuning Guide (tensor core formats & accumulator notes). [2]
- TensorRT — Best practices / Accumulation precision for MHA/BMM. [3]
- NVIDIA Developer Blog — Optimizing LLMs with TensorRT-LLM (notes on gemm plugin with FP32 accumulation). [4]
- Torch-TensorRT mixed-precision doc (use_fp32_acc flag). [5]
🌐 Web query:
vectorized accumulation bf16 fp16 CUDA best practices precision loss
💡 Result:
Summary of practical recommendations for vectorized accumulation with BF16/FP16 on CUDA (minimize precision loss)
-
Prefer higher-precision accumulation. Use FP32 accumulators for reductions / GEMM outputs where possible (Tensor Cores on Ampere/Hopper/H100 accumulate to FP32; frameworks/NVIDIA recommend FP32 accumulation for FP16/BF16 workloads). [1][2]
-
Use mixed-precision correctly: keep compute in BF16/FP16 for throughput but cast to FP32 for accumulation, weight updates, and final sensitive ops (softmax, final output). Use loss-scaling for FP16 training to avoid underflow. [2][4]
-
Numerically stable reduction patterns: use pairwise / tree-reduce or blockwise pairwise summation (rather than simple linear accumulate) to reduce rounding error in long vectors; when extreme accuracy is required, use compensated summation (Kahan or Neumaier) for optimizer accumulators or long-running sums — this is cheap memory-wise for BF16 and yields much lower error. [3]
-
Leverage hardware/SDK features: use vendor fused kernels (cuBLAS/cuDNN/TensorRT) that choose safe accumulation precision and optimized tiling/stride for Tensor Cores; follow their recommended tensor shapes (multiples of 8/16) for best performance/accuracy tradeoffs. [1][2]
-
Vectorization & memory layout:
- Align loads/stores and use contiguous layouts (NHWC vs NCHW depending on kernel) to maximize throughput and avoid extra conversions that can expose precision loss.
- Process reductions in tiles that fit L1/shared memory to accumulate locally in FP32, then reduce across tiles in FP32 to minimize cross-thread rounding. [2]
-
Keep a thin FP32 “shadow” where needed:
- Maintain FP32 master weights or FP32 optimizer states (or compensated sums) for training; cast to BF16/FP16 for forward/backward to save memory while preserving update accuracy. [2][3]
-
Quantify and test:
- Compare to an FP32 baseline, profile error growth per layer/step, and selectively keep sensitive layers (embeddings, softmax, layernorm, rotary buffers) in FP32 if you observe instability. [3][4]
-
Performance/accuracy knobs:
- If hardware/SDK supports selecting accumulation precision, prefer FP32 accumulation for BF16/FP16 BMM/BMM2 (default on modern NVIDIA architectures); only lower accumulation precision for a tight performance win after careful validation. [1]
Primary references
- NVIDIA best-practices / TensorRT & mixed-precision guidance (accumulation precision, Tensor Core behavior, shape constraints). [1][2]
- Low-precision training with compensated summation (Kahan) — practical for BF16/FP16 optimizers. [3]
- Mixed-precision coding tips (cast to FP32 for accuracy-sensitive ops, align/vectorize memory). [4]
If you want, I can: (a) give a short CUDA kernel pattern showing tile-local FP16 compute + FP32 tile accumulation + FP32 final reduction, or (b) show a small PyTorch example using Kahan compensation for optimizer updates. Which would you like?
🏁 Script executed:
# Search for similar accumulation patterns in the codebase
rg -A 10 -B 5 "fp32.*accum|float.*accum" --type cuRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
# Check for existing bf16/fp16 reduction patterns in TensorRT-LLM kernels
rg -A 15 "reinterpret_cast.*\*>" csrc/nv_internal/tensorrt_llm/kernels --type cu | grep -A 15 -B 5 "half\|bf16"Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
# Search for fp32 accumulation patterns in kernel files
find csrc -name "*.cu" -o -name "*.cuh" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 627
🏁 Script executed:
# Search for similar reduction/combine patterns with float accumulation
rg "float.*accum|FP32.*accum" csrc/nv_internal/tensorrt_llm/kernels -A 5 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 4072
🏁 Script executed:
# Check for test tolerances in MOE/combine related tests
fd "moe.*test|combine.*test" csrc --type fRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for explicit comments about precision or tolerance in combine kernels
rg "precision|tolerance|accumul" csrc/nv_internal/tensorrt_llm/kernels/communicationKernels --type cpp --type cu -i -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
# Read the full file around the vectorized_combine_impl function
wc -l csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cuRepository: flashinfer-ai/flashinfer
Length of output: 152
🏁 Script executed:
# Look at the context around lines 513-613 and find function caller/dtype handling
sed -n '480,650p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 8455
🏁 Script executed:
# Check for any comments about precision or dtype conversion in the combine function
sed -n '450,700p' csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu | grep -i "precision\|dtype\|float\|half\|bf16\|accum" -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 917
Use FP32 accumulation for BF16/FP16 combine to match NVIDIA best practices.
The current code accumulates directly in T (half/bf16) precision, which causes precision loss—especially for TOP_K=8 with multiple sequential additions. TensorRT-LLM and NVIDIA kernels (MOE GEMM in this codebase, cuBLAS/cuDNN) use FP32 accumulators for BF16/FP16 inputs. Restructure vectorized_combine_impl to load BF16/FP16 values, accumulate in FP32, then cast back to T for output.
🤖 Prompt for AI Agents
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
lines 513-613: the kernel currently accumulates directly in template type T
(FP16/BF16), causing precision loss; change the accumulation to use FP32
temporaries by loading values from recv_buffer into a float-vector accumulator,
convert each loaded element from T->float before adding, perform all per-k
reductions in float, and finally cast the reduced float results back to T when
storing to dst_bytes; ensure vector types and element counts (elems_per_vec) are
adjusted to use vec_t<float, VEC_SIZE> for acc, update all reinterpret_casts and
per-element add loops to operate on float arrays, and perform a final
conversion/store from float to T at acc[0].store(dst_bytes + offset).
| #include <cuda_bf16.h> | ||
| #include <cuda_fp16.h> | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Public header is missing required includes for types it uses (e.g., nvinfer1::DataType, cudaStream_t).
MoeA2ACombineParams references nvinfer1::DataType, and both param structs use cudaStream_t, but this header doesn’t include the defining headers. Relying on include order will eventually break downstream builds.
#pragma once
+#include <cuda_runtime_api.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+
+#include "tensorrt_llm/common/dataType.h" // for nvinfer1::DataTypeAlso applies to: 128-166
🤖 Prompt for AI Agents
In
csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h
around lines 18-20 (and similarly for the structs referenced in lines 128-166),
the header uses nvinfer1::DataType and cudaStream_t but doesn't include their
defining headers; add the proper public includes (e.g., #include <NvInfer.h> for
nvinfer1::DataType and #include <cuda_runtime.h> for cudaStream_t) near the top
of the header so the types are defined without relying on include order.
| @flashinfer_api | ||
| def moe_a2a_wrap_payload_tensor_in_workspace( | ||
| workspace: torch.Tensor, | ||
| leading_shape: list[int], | ||
| slice_start: int, | ||
| slice_end: int, | ||
| dtype: torch.dtype, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Wrap an offset in the workspace into a tensor. | ||
| Args: | ||
| workspace: [ep_size, size_per_rank] or [size_per_rank] workspace tensor | ||
| leading_shape: The leading shape to wrap the tensor with | ||
| slice_start: The start of the slice in the workspace | ||
| slice_end: The end of the slice in the workspace | ||
| dtype: Data type for the output tensor | ||
| Returns: | ||
| tensor: [leading_shape, *] workspace-backed tensor | ||
| """ | ||
| if workspace.ndim == 1: | ||
| workspace = workspace.unsqueeze(0) | ||
| workspace_base = workspace.view(dtype=torch.uint8) | ||
| assert workspace.ndim == 2, "workspace must be shape [ep_size, size_per_rank]" | ||
| assert slice_end - slice_start <= workspace_base.shape[1], ( | ||
| "slice_end - slice_start must belong to a single rank" | ||
| ) | ||
| slice_rank = slice_start // workspace_base.stride(0) | ||
| local_slice_start = slice_start % workspace_base.stride(0) | ||
| slice_length = slice_end - slice_start | ||
| local_slice_end = local_slice_start + slice_length | ||
| assert local_slice_end <= workspace_base.shape[1], ( | ||
| "slice must fall within the workspace size per rank" | ||
| ) | ||
| result = ( | ||
| workspace_base[slice_rank, local_slice_start:local_slice_end] | ||
| .view(dtype=dtype) | ||
| .view(*leading_shape, -1) | ||
| ) | ||
| return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add alignment checks for workspace slicing before .view(dtype) to avoid runtime reshape errors.
Right now, misaligned slice_start/slice_end will throw from PyTorch .view(dtype=...) with a hard-to-debug error.
slice_length = slice_end - slice_start
local_slice_end = local_slice_start + slice_length
+ element_size = torch.tensor([], dtype=dtype).element_size()
+ assert local_slice_start % element_size == 0, "slice_start must be dtype-aligned"
+ assert slice_length % element_size == 0, "slice length must be dtype-aligned"
assert local_slice_end <= workspace_base.shape[1], (
"slice must fall within the workspace size per rank"
)🤖 Prompt for AI Agents
flashinfer/comm/trtllm_moe_alltoall.py lines 213-253: the code slices the
workspace viewed as torch.uint8 then calls .view(dtype=...) which will raise an
opaque runtime error if the slice start or length are not aligned to the target
dtype; add explicit alignment checks before calling .view: compute the target
element size (bytes) for the provided dtype and assert that local_slice_start %
itemsize == 0 and slice_length % itemsize == 0 (and keep the existing check that
the slice fits within the per-rank bytes), and raise a clear AssertionError
describing required byte alignment if the checks fail so callers get a readable
error instead of a PyTorch reshape failure.
| class MoeAlltoAll: | ||
| """ | ||
| Manages MoE All-to-All operations with proper workspace allocation and synchronization. | ||
| This class provides the throughput-optimized backend that supports multiple payloads | ||
| per collective operation, explicit dispatch/combine phases, and workspace-backed tensors. | ||
| Example: | ||
| >>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8) | ||
| >>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size) | ||
| >>> output = moe_a2a.combine(processed, batch_size) | ||
| """ | ||
|
|
||
| # Single shared workspace across the process | ||
| # _WORKSPACE: Optional[dict] = None | ||
| _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} | ||
|
|
||
| @classmethod | ||
| def get_workspace( | ||
| cls, | ||
| workspace_size_per_rank: int, | ||
| ep_rank: int, | ||
| ep_size: int, | ||
| max_num_tokens: int, | ||
| mapping: Mapping, | ||
| ) -> dict: | ||
| key = (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens) | ||
| if key in cls._WORKSPACE_CACHE: | ||
| return cls._WORKSPACE_CACHE[key] | ||
| else: | ||
| mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) | ||
| workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) | ||
| metainfo = moe_a2a_initialize( | ||
| workspace, | ||
| ep_rank, | ||
| ep_size, | ||
| max_num_tokens, | ||
| ) | ||
| cls._WORKSPACE_CACHE[key] = { | ||
| "workspace_size_per_rank": workspace_size_per_rank, | ||
| "max_num_tokens": max_num_tokens, | ||
| "ep_rank": ep_rank, | ||
| "ep_size": ep_size, | ||
| "mnnvl_mem": mnnvl_mem, | ||
| "workspace": workspace, | ||
| "metainfo": metainfo, | ||
| } | ||
| return cls._WORKSPACE_CACHE[key] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, let's understand the structure and locate relevant files
cd /tmp/repo 2>/dev/null || cd .
find . -type f -name "*.py" | grep -E "(mnnvl|moe_alltoall)" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 426
🏁 Script executed:
# Search for MnnvlMemory class definition
rg "class MnnvlMemory" -A 20Repository: flashinfer-ai/flashinfer
Length of output: 1128
🏁 Script executed:
# Search for set_comm_from_config method
rg "set_comm_from_config" -B 5 -A 15Repository: flashinfer-ai/flashinfer
Length of output: 7821
🏁 Script executed:
# Look for Mapping class and its usage
rg "class Mapping" -A 10Repository: flashinfer-ai/flashinfer
Length of output: 533
🏁 Script executed:
# Search for get_workspace usage and calls
rg "get_workspace" -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 50380
🏁 Script executed:
# Look at the moe_a2a_initialize function and how it uses metainfo
rg "def moe_a2a_initialize" -A 30Repository: flashinfer-ai/flashinfer
Length of output: 4390
🏁 Script executed:
# Search for tests related to MoeAlltoAll or workspace caching
fd "test.*\.py" | xargs rg "MoeAlltoAll\|_WORKSPACE_CACHE" 2>/dev/null | head -30Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Get file size first
wc -l flashinfer/comm/trtllm_moe_alltoall.pyRepository: flashinfer-ai/flashinfer
Length of output: 110
🏁 Script executed:
# Find Mapping definition
rg "class Mapping" -A 30 flashinfer/comm/Repository: flashinfer-ai/flashinfer
Length of output: 1374
🏁 Script executed:
# See the full get_workspace and __init__ methods
sed -n '385,480p' flashinfer/comm/trtllm_moe_alltoall.pyRepository: flashinfer-ai/flashinfer
Length of output: 3972
🏁 Script executed:
# Check if Mapping stores any config/communicator identity
python3 << 'EOF'
import os
os.chdir('.')
# Try to understand what's in a Mapping object
with open('flashinfer/comm/trtllm_moe_alltoall.py', 'r') as f:
content = f.read()
# Find where Mapping is used
if 'Mapping' in content:
print("Mapping is used in trtllm_moe_alltoall.py")
# Find the get_workspace method
start = content.find('def get_workspace')
if start > -1:
end = content.find('\n def ', start + 10)
print("\n--- get_workspace method ---")
print(content[start:end])
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
# List all Python files in comm directory
ls -la flashinfer/comm/*.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 1036
🏁 Script executed:
# Get the exact line range mentioned in the review
sed -n '393,440p' flashinfer/comm/trtllm_moe_alltoall.pyRepository: flashinfer-ai/flashinfer
Length of output: 1870
🏁 Script executed:
# Search for set_comm_from_config in the entire codebase
rg "set_comm_from_config" -B 10 -A 20Repository: flashinfer-ai/flashinfer
Length of output: 10774
🏁 Script executed:
# Look for any global state or class variables in MnnvlMemory or related classes
rg "class MnnvlMemory|_comm|_config|global" flashinfer/ --type py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 2525
🏁 Script executed:
# Look for test files
find . -path "*/test*" -name "*.py" -type f -exec grep -l "MoeAlltoAll\|workspace" {} \; 2>/dev/null | head -5Repository: flashinfer-ai/flashinfer
Length of output: 292
🏁 Script executed:
# Check if the mapping parameter is passed to MnnvlMemory or moe_a2a_initialize
python3 << 'EOF'
with open('flashinfer/comm/trtllm_moe_alltoall.py', 'r') as f:
content = f.read()
# Find get_workspace method
import re
match = re.search(r'def get_workspace\([^)]*\):[^}]+?(?=\n @classmethod|\n def |\nclass |\Z)', content, re.DOTALL)
if match:
method_content = match.group(0)
print("=== get_workspace method ===")
print(method_content)
print("\n=== Analysis ===")
if 'mapping' in method_content.lower():
print("✓ 'mapping' is used in the method")
# Check how it's used
if 'MnnvlMemory(mapping' in method_content:
print("✓ mapping is passed to MnnvlMemory")
if 'moe_a2a_initialize(workspace' in method_content:
print("✓ Checking moe_a2a_initialize signature")
else:
print("✗ 'mapping' parameter is NOT used in get_workspace!")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
# Double-check the get_workspace method parameters and usage
sed -n '402,430p' flashinfer/comm/trtllm_moe_alltoall.pyRepository: flashinfer-ai/flashinfer
Length of output: 1056
🏁 Script executed:
# Check if there's any documentation or comments about what Mapping does
rg "Mapping" flashinfer/comm/trtllm_moe_alltoall.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 212
🏁 Script executed:
# Final verification: Does the mapping parameter affect workspace behavior?
python3 << 'EOF'
# Key findings:
# 1. get_workspace takes 'mapping' parameter
# 2. It passes mapping to MnnvlMemory(mapping, workspace_size_per_rank)
# 3. But the cache key is: (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens)
# 4. The mapping parameter is NOT included in the cache key
# 5. Different Mapping objects could potentially have different communicator configurations
# Let's check if mapping should be hashable for use in cache key
with open('flashinfer/comm/trtllm_moe_alltoall.py', 'r') as f:
content = f.read()
# Check if Mapping is from typing or a custom class
if 'from typing' in content and 'Mapping' in content:
print("Mapping might be from typing.Mapping (dict-like interface)")
elif 'Mapping' in content:
print("Mapping is likely a custom class that needs investigation")
# Check what MnnvlMemory does with the mapping
print("\nSearching for MnnvlMemory usage pattern...")
import re
matches = re.findall(r'MnnvlMemory\([^)]+\)', content)
for match in matches:
print(f" {match}")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
# Check the full imports section
python3 << 'EOF'
import re
with open('flashinfer/comm/trtllm_moe_alltoall.py', 'r') as f:
lines = f.readlines()
# Print first 50 lines to see all imports
for i in range(min(50, len(lines))):
print(f"{i+1:3d}: {lines[i]}", end='')
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
sed -n '400,410p' flashinfer/comm/trtllm_moe_alltoall.pyRepository: flashinfer-ai/flashinfer
Length of output: 489
🏁 Script executed:
# Final check: Look for the Mapping type definition or import
rg "Mapping" flashinfer/comm/ --type py -B 2 -A 2 | grep -E "class Mapping|from.*Mapping|import.*Mapping" | head -10Repository: flashinfer-ai/flashinfer
Length of output: 540
🏁 Script executed:
# Get the complete method signature
python3 << 'EOF'
with open('flashinfer/comm/trtllm_moe_alltoall.py', 'r') as f:
lines = f.readlines()
# Show get_workspace signature
for i in range(400, 410):
print(f"{i+1}: {lines[i]}", end='')
EOFRepository: flashinfer-ai/flashinfer
Length of output: 154
Workspace cache reuses MnnvlMemory instances across different communicator configurations, causing a correctness bug.
The get_workspace() method caches MnnvlMemory instances using only (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens) as the key, but MnnvlMemory stores the communicator as class-level state via MnnvlMemory.set_comm_from_config(). When two MoeAlltoAll instances are created with the same size parameters but different mapping objects (and thus different communicator configurations), the second instance reuses the cached workspace that was initialized for the first mapping, but the class-level MnnvlMemory.comm has been overwritten for the second mapping. This causes all-to-all operations to execute with mismatched workspace and communicator, leading to data corruption. The cache key must include mapping/communicator identity, or the caching mechanism must be redesigned to avoid reuse across different communicator configurations.
🧰 Tools
🪛 Ruff (0.14.8)
408-408: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_moe_alltoall.py around lines 393 to 440, the workspace
cache key omits the mapping/communicator identity which leads to reusing
MnnvlMemory (that holds class-level comm state) across different mappings and
corrupts all-to-all ops; fix by including a stable identifier for the
mapping/communicator in the cache key (e.g., mapping id/hash or
mapping.get_comm_config()) or avoid caching MnnvlMemory across different
communicator configs (store per-mapping cache entries or validate that stored
metainfo/comm matches the provided mapping and recreate if not); ensure
get_workspace uses this extended key or performs the validation before returning
the cached workspace.
| payloads, expert_id_payload_index = make_bfloat16_payloads( | ||
| local_num_tokens, hidden_size, top_k, rank, token_selected_experts | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Silence unused var warning (expert_id_payload_index) in combine test.
This is just noise, but it will keep lint clean.
- payloads, expert_id_payload_index = make_bfloat16_payloads(
+ payloads, _expert_id_payload_index = make_bfloat16_payloads(
local_num_tokens, hidden_size, top_k, rank, token_selected_experts
)🧰 Tools
🪛 Ruff (0.14.8)
706-706: Unpacked variable expert_id_payload_index is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In tests/comm/test_mnnvl_moe_alltoall.py around lines 706 to 708, the variable
expert_id_payload_index returned by make_bfloat16_payloads is assigned but
unused, causing a lint unused-variable warning; silence this by explicitly
marking it as unused (for example assign to a prefixed underscore or add a dummy
assignment like _ = expert_id_payload_index or wrap it in a no-op usage),
keeping the original call and variables intact so the test behavior is
unchanged.
📌 Description
This ports the latest MNNVL A2A communication implementation from TRT-LLM
🔍 Related Issues
#2094
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests
✏️ Tip: You can customize this high-level summary in your review settings.