[Public release 26/04] Introducing EPv2: faster EP, and Engram/PP/CP supports #605

Merged: LyricZhao merged 7 commits into main from epv2-release on Apr 29, 2026

@LyricZhao
Collaborator

@LyricZhao LyricZhao commented Apr 23, 2026

With the evolution of hardware, networking, and model architectures, the previous DeepEP V1 had accumulated too much legacy baggage and performance issues. Today, we are excited to introduce DeepEP V2, which includes a complete refactoring of Expert Parallelism — achieving extreme performance with several times fewer SM resources compared to V1, while supporting significantly larger scale-up and scale-out domains — as well as experimental 0 SM Engram, 0 SM Pipeline Parallelism, and 0 SM Context Parallelism all-gather.

We are also happy to announce that we have switched from the NVSHMEM backend to the more lightweight NCCL Gin backend.


New Features

  • Fully JIT (Just-In-Time compilation)
  • NCCL Gin backend
    • Header-only & lightweight
    • Able to reuse existing NCCL communicators
  • EPv2
    • CUDA graph compatible, or faster CPU sync in memory-saving mode
    • High-throughput and low-latency APIs unified into a single interface, with a new GEMM layout
    • Larger scale-up & scale-out domain support (up to EP2048)
    • Analytical SM & QP count calculation — no more auto-tuning needed
    • Both hybrid & direct modes remain supported
    • For V3-like legacy training, SM usage is reduced from 24 to 4-6 while maintaining equivalent or better performance
  • 0 SM Engram (with RDMA)
  • 0 SM PP (with RDMA)
  • 0 SM CP (with Copy Engine)

Notes

  • Buffer size consumption is larger than V1
  • 0 SM RDMA low-latency EP is no longer supported
  • Engram, PP, and CP are experimental features

Features Still in Progress

  • Elastic GPU & CPU buffers: A contiguous virtual address space that maps to a hybrid of GPU and CPU physical memory under the hood, enabling fully automatic and transparent Engram or imbalanced EP
  • Reducing intermediate buffer sizes by leveraging EP replay to handle load imbalance
  • All-gather updates and reduce-scatter implementations for DP & TP

Performance

Following V3's configuration, we tested with 8K tokens per batch, 7168 hidden dimensions, top 8 experts, FP8 dispatching, and BF16 combining, and obtained the following results:

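| Arch  | NIC type | Topo     | Dispatch Bottleneck Bandwidth | Combine Bottleneck Bandwidth | #SMs          |
|-------|----------|----------|-------------------------------|------------------------------|---------------|
| SM90  | CX7      | EP 8 x 2 | 90 GB/s (RDMA)                | 81 GB/s (RDMA)               | 12            |
| SM90  | CX7      | EP 8 x 4 | 61 GB/s (RDMA)                | 61 GB/s (RDMA)               | 6             |
| SM100 | CX7      | EP 8 x 2 | 90 GB/s (RDMA)                | 91 GB/s (RDMA)               | 12            |
| SM100 | N/A      | EP 8     | 726 GB/s (NVLink)             | 740 GB/s (NVLink)            | 64 (Max perf) |
| SM100 | N/A      | EP 8     | 643 GB/s (NVLink)             | 675 GB/s (NVLink)            | 24 (Min #SM)  |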

Note: the results are logical bandwidth. For example, in the EP 8 x 2 case, the 90 GB/s figure includes local-rank traffic.

Compared with V1, V2 achieves up to 1.3x peak performance while using up to 4x fewer SMs.

We omit results for larger EP configurations for the time being, but encourage interested users to benchmark them directly. Based on our internal experience, we expect the kernel to continue saturating hardware bandwidth at scale.
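
As a rough illustration of the test configuration above, the size of one copy of a batch can be estimated from the token count, hidden size, and element width. This is a back-of-the-envelope sketch, not DeepEP code: `payload_bytes` is a hypothetical helper, "8K tokens" is assumed to mean 8192, and FP8 scaling factors, top-k replication, and routing metadata are all ignored.

```python
def payload_bytes(num_tokens: int, hidden: int, bytes_per_elem: int) -> int:
    """Bytes needed to hold one copy of every token's hidden vector."""
    return num_tokens * hidden * bytes_per_elem


NUM_TOKENS = 8192  # assumption: "8K tokens" means 8 * 1024
HIDDEN = 7168      # hidden dimension from the configuration above

fp8_dispatch = payload_bytes(NUM_TOKENS, HIDDEN, 1)  # FP8: 1 byte per element
bf16_combine = payload_bytes(NUM_TOKENS, HIDDEN, 2)  # BF16: 2 bytes per element

print(f"FP8 dispatch payload per copy: {fp8_dispatch / 1e6:.1f} MB")
print(f"BF16 combine payload per copy: {bf16_combine / 1e6:.1f} MB")
```

Under these assumptions the combine payload is exactly twice the dispatch payload, which is one reason the two directions are reported separately in the table.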


Contributors

@LyricZhao LyricZhao requested a review from sphish April 23, 2026 05:54
@alpha-baby
Contributor

The build failed on CUDA 12.8.

Dependencies:

nvidia-nccl-cu12                         2.30.4
nvidia-nvshmem-cu12                      3.5.19
export PATH=/usr/local/cuda/bin:$PATH
export EP_NVSHMEM_ROOT_DIR=/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem
export EP_NCCL_ROOT_DIR=/opt/conda/lib/python3.10/site-packages/nvidia/nccl
python setup.py bdist_wheel
/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Build summary:
 > Sources: ['csrc/python_api.cpp', 'csrc/kernels/legacy/layout.cu', 'csrc/kernels/legacy/intranode.cu', 'csrc/kernels/legacy/internode.cu', 'csrc/kernels/legacy/internode_ll.cu', 'csrc/kernels/backend/nvshmem.cu', 'csrc/kernels/backend/nccl.cu', 'csrc/kernels/backend/cuda_driver.cu']
 > Includes: ['/root/AntDeepEP-v2-bak/deep_ep/include', '/root/AntDeepEP-v2-bak/third-party/fmt/include', '/usr/local/cuda/include/cccl', '/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include', '/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include']
 > Libraries: ['/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib']
 > Compilation flags: {'cxx': ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc': ['-O3', '-Xcompiler', '-O3', '--extended-lambda', '--diag-suppress=128,2417', '-rdc=true', '--ptxas-options=--register-usage-level=10', '-DDISABLE_AGGRESSIVE_PTX_INSTRS'], 'nvcc_dlink': ['-dlink', '-L/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib', '-lnvshmem_device']}
 > Link flags: ['-lcuda', '-l:libnvshmem_host.so', '-l:libnvshmem_device.a', '-Wl,-rpath,/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/lib', '-l:libnccl.so', '-Wl,-rpath,/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib']
 > Arch list: 9.0
 > NVSHMEM path: /opt/conda/lib/python3.10/site-packages/nvidia/nvshmem
 > NCCL path: /opt/conda/lib/python3.10/site-packages/nvidia/nccl
 > Persistent envs:
   > EP_NCCL_ROOT_DIR: /opt/conda/lib/python3.10/site-packages/nvidia/nccl

running bdist_wheel
running build
running build_py
copying deep_ep/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep
creating build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/envs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/refs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/find_pkgs.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/gate.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/semantic.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/event.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/testing.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/math.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
copying deep_ep/utils/comm.py -> build/lib.linux-x86_64-cpython-310/deep_ep/utils
creating build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/elastic.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/__init__.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
copying deep_ep/buffers/legacy.py -> build/lib.linux-x86_64-cpython-310/deep_ep/buffers
creating build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/math.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/ptx.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/comm.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/compiled.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/handle.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/exception.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
copying deep_ep/include/deep_ep/common/layout.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/common
creating build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/hybrid_dispatch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine_utils.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/barrier.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/hybrid_combine.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch_copy_epilogue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/dispatch_deterministic_prologue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/pp_send_recv.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/combine_reduce_epilogue.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
copying deep_ep/include/deep_ep/impls/engram_fetch.cuh -> build/lib.linux-x86_64-cpython-310/deep_ep/include/deep_ep/impls
running build_ext
W0423 20:44:59.975000 77927 site-packages/torch/utils/cpp_extension.py:531] There are no g++ version bounds defined for CUDA version 12.8
building 'deep_ep._C' extension
creating /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend
creating /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy
[1/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/layout.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/layout.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/layout.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[2/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/intranode.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/intranode.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/intranode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[3/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nvshmem.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/nvshmem.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nvshmem.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[4/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/cuda_driver.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/cuda_driver.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/cuda_driver.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[5/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nccl.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/backend/nccl.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/backend/nccl.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[6/9] c++ -MMD -MF /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o.d -pthread -B /opt/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/python_api.cpp -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o -O3 -Wno-deprecated-declarations -Wno-unused-variable -Wno-sign-compare -Wno-reorder -Wno-attributes -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -std=c++17
FAILED: /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o 
c++ -MMD -MF /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o.d -pthread -B /opt/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /opt/conda/include -fPIC -O2 -isystem /opt/conda/include -fPIC -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/python_api.cpp -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/python_api.o -O3 -Wno-deprecated-declarations -Wno-unused-variable -Wno-sign-compare -Wno-reorder -Wno-attributes -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -std=c++17
In file included from /usr/local/cuda/include/cuda/__ptx/instructions/barrier_cluster.h:26,
                 from /usr/local/cuda/include/cuda/ptx:72,
                 from /usr/local/cuda/include/cuda/barrier:24,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/ptx.cuh:3,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/layout.cuh:6,
                 from /root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:8,
                 from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘uint32_t cuda::ptx::__4::__as_ptr_smem(const void*)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:42:44: error: ‘__cvta_generic_to_shared’ was not declared in this scope
   42 |   return static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(__ptr));
      |                                            ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘uint64_t cuda::ptx::__4::__as_ptr_gmem(const void*)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:62:44: error: ‘__cvta_generic_to_global’ was not declared in this scope
   62 |   return static_cast<_CUDA_VSTD::uint64_t>(__cvta_generic_to_global(__ptr));
      |                                            ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘_Tp* cuda::ptx::__4::__from_ptr_smem(size_t)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:74:33: error: there are no arguments to ‘__cvta_shared_to_generic’ that depend on a template parameter, so a declaration of ‘__cvta_shared_to_generic’ must be available [-fpermissive]
   74 |   return reinterpret_cast<_Tp*>(__cvta_shared_to_generic(__ptr));
      |                                 ^~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:74:33: note: (if you use ‘-fpermissive’, G++ will accept your code, but allowing the use of an undeclared name is deprecated)
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h: In function ‘_Tp* cuda::ptx::__4::__from_ptr_gmem(size_t)’:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:95:33: error: there are no arguments to ‘__cvta_global_to_generic’ that depend on a template parameter, so a declaration of ‘__cvta_global_to_generic’ must be available [-fpermissive]
   95 |   return reinterpret_cast<_Tp*>(__cvta_global_to_generic(__ptr));
      |                                 ^~~~~~~~~~~~~~~~~~~~~~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp: In member function ‘std::function<at::Tensor()> deep_ep::elastic::ElasticBuffer::engram_fetch(const at::Tensor&, int) const’:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:255:20: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
  255 |         return [=, this]() {
      |                    ^~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp: In member function ‘std::pair<std::vector<at::Tensor>, std::function<void()> > deep_ep::elastic::ElasticBuffer::all_gather(const std::vector<at::Tensor>&)’:
/root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:461:27: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
  461 |         auto handle = [=, this]() {
      |                           ^~~~
In file included from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:8:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘void deep_ep::legacy::Buffer::clean_low_latency_buffer(int, int, int)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1438:35: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1438 |         auto check_boundary = [=, this](void* ptr, size_t num_bytes) {
      |                                   ^~~~
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, std::optional<deep_ep::EventHandle>, std::optional<std::function<void()> > > deep_ep::legacy::Buffer::low_latency_dispatch(const at::Tensor&, const at::Tensor&, const std::optional<at::Tensor>&, const std::optional<at::Tensor>&, int, int, bool, bool, bool, bool, bool)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1545:29: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1545 |         auto launcher = [=, this](int phases) {
      |                             ^~~~
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp: In member function ‘std::tuple<at::Tensor, std::optional<deep_ep::EventHandle>, std::optional<std::function<void()> > > deep_ep::legacy::Buffer::low_latency_combine(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const std::optional<at::Tensor>&, int, int, bool, bool, bool, bool, const std::optional<at::Tensor>&)’:
/root/AntDeepEP-v2-bak/csrc/legacy/buffer.hpp:1668:29: warning: explicit by-copy capture of ‘this’ redundant with by-copy capture default
 1668 |         auto launcher = [=, this](int phases) {
      |                             ^~~~
[7/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode_ll.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/internode_ll.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode_ll.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
[8/9] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode.o.d -I/root/AntDeepEP-v2-bak/deep_ep/include -I/root/AntDeepEP-v2-bak/third-party/fmt/include -I/usr/local/cuda/include/cccl -I/opt/conda/lib/python3.10/site-packages/nvidia/nvshmem/include -I/opt/conda/lib/python3.10/site-packages/nvidia/nccl/include -I/opt/conda/lib/python3.10/site-packages/torch/include -I/opt/conda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda/include -I/opt/conda/include/python3.10 -c -c /root/AntDeepEP-v2-bak/csrc/kernels/legacy/internode.cu -o /root/AntDeepEP-v2-bak/build/temp.linux-x86_64-cpython-310/csrc/kernels/legacy/internode.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -Xcompiler -O3 --extended-lambda --diag-suppress=128,2417 -rdc=true --ptxas-options=--register-usage-level=10 -DDISABLE_AGGRESSIVE_PTX_INSTRS -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -gencode=arch=compute_90,code=sm_90 -std=c++17
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2597, in _run_ninja_build
    subprocess.run(
  File "/opt/conda/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/AntDeepEP-v2-bak/setup.py", line 170, in <module>
    setuptools.setup(
  File "/opt/conda/lib/python3.10/site-packages/setuptools/__init__.py", line 115, in setup
    return distutils.core.setup(**attrs)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 186, in setup
    return run_commands(dist)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 202, in run_commands
    dist.run_commands()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1002, in run_commands
    self.run_command(cmd)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/bdist_wheel.py", line 370, in run
    self.run_command("build")
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 357, in run_command
    self.distribution.run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build.py", line 135, in run
    self.run_command(cmd_name)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 357, in run_command
    self.distribution.run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/dist.py", line 1102, in run_command
    super().run_command(command)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 1021, in run_command
    cmd_obj.run()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 96, in run
    _build_ext.run(self)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 368, in run
    self.build_extensions()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1082, in build_extensions
    build_ext.build_extensions(self)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 484, in build_extensions
    self._build_extensions_serial()
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 510, in _build_extensions_serial
    self.build_extension(ext)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 261, in build_extension
    _build_ext.build_extension(self, ext)
  File "/opt/conda/lib/python3.10/site-packages/Cython/Distutils/build_ext.py", line 135, in build_extension
    super(build_ext, self).build_extension(ext)
  File "/opt/conda/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 565, in build_extension
    objects = self.compiler.compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 866, in unix_wrap_ninja_compile
    _write_ninja_file_and_compile_objects(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2223, in _write_ninja_file_and_compile_objects
    _run_ninja_build(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 2614, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension

@alpha-baby
Contributor

Can't cuda::barrier be used in host code?

@alpha-baby
Contributor

cuda/barrier contains device-only CUDA intrinsics (such as __cvta_generic_to_shared), which will cause compilation errors in host code.

The previous issue was caused by the following include chain: python_api.cpp (host code) -> elastic/buffer.hpp -> layout.cuh -> ptx.cuh -> cuda/barrier.
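
The include chain can be read directly off the "In file included from" header that GCC printed in the build log above. The following is a small diagnostic sketch, not part of DeepEP: `include_chain` is a hypothetical helper, and it assumes GCC's usual layout in which the innermost header is listed first and the translation unit last.

```python
import re

# Matches "In file included from <path>:<line>," and the continuation
# lines "                 from <path>:<line>," (the last one ends with ':').
_INCLUDE_RE = re.compile(r"\s*(?:In file included from|from)\s+(\S+):(\d+)[,:]\s*$")


def include_chain(gcc_output: str) -> list:
    """Return the include chain of the first header block, outermost file first."""
    chain = []
    for line in gcc_output.splitlines():
        match = _INCLUDE_RE.match(line)
        if match:
            chain.append(match.group(1))
        elif chain:
            break  # the first non-matching line ends the header block
    return list(reversed(chain))


# Excerpt copied from the build log above.
log = """\
In file included from /usr/local/cuda/include/cuda/__ptx/instructions/barrier_cluster.h:26,
                 from /usr/local/cuda/include/cuda/ptx:72,
                 from /usr/local/cuda/include/cuda/barrier:24,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/ptx.cuh:3,
                 from /root/AntDeepEP-v2-bak/deep_ep/include/deep_ep/common/layout.cuh:6,
                 from /root/AntDeepEP-v2-bak/csrc/elastic/buffer.hpp:8,
                 from /root/AntDeepEP-v2-bak/csrc/python_api.cpp:7:
/usr/local/cuda/include/cuda/__ptx/ptx_helper_functions.h:42:44: error: '__cvta_generic_to_shared' was not declared in this scope
"""

chain = include_chain(log)
print(" -> ".join(chain))
```

The printed chain starts at the host translation unit `python_api.cpp` and passes through `buffer.hpp`, `layout.cuh`, and `ptx.cuh` before reaching `cuda/barrier`, matching the chain described above.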

@sphish
Collaborator

sphish commented Apr 24, 2026

> cuda/barrier contains device-only CUDA intrinsics (such as __cvta_generic_to_shared), which will cause compilation errors in host code.
>
> The previous issue was caused by the following include chain: python_api.cpp (host code) -> elastic/buffer.hpp -> layout.cuh -> ptx.cuh -> cuda/barrier.

This appears to be caused by some issues with CUDA 12.8. Thank you for pointing it out; a workaround has been applied.

@QizhouZhang97

Hi DeepEP developers, I'm interested in replacing NVSHMEM with NCCL GIN. What would be the main benefits of making that switch?

@alpha-baby
Contributor

> > cuda/barrier contains device-only CUDA intrinsics (such as __cvta_generic_to_shared), which will cause compilation errors in host code.
> > The previous issue was caused by the following include chain: python_api.cpp (host code) -> elastic/buffer.hpp -> layout.cuh -> ptx.cuh -> cuda/barrier.
>
> This appears to be caused by some issues with CUDA 12.8. Thank you for pointing it out; a workaround has been applied.

Thanks very much. We can now build and run on CUDA 12.8 using this branch: https://github.com/deepseek-ai/DeepEP/tree/try-fix-cu128

@gangxie112

gangxie112 commented Apr 27, 2026

It seems that we don't use companion QPs at all, so creating them wastes QPs. Is there any way to avoid creating companion QPs?

@gangxie112

> It seems that we don't use companion QPs at all, so creating them wastes QPs. Is there any way to avoid creating companion QPs?

Two GIN issues could impact DeepEP V2:

  1. Too many companion QPs are created even if we don't use them. NCCL will resolve this in the next release.
  2. A load-balance issue; see "[Issue]: GIN main qps load balance problem when use a lag port with queue affinity policy" (NVIDIA/nccl#2136).

* enhance: add env:EP_NIC_NAME to config nic name

---------

Co-authored-by: fujianhao.fjh <fujianhao.fjh@alipay.com>
@xiaofanl-nvidia

> It seems that we don't use companion QPs at all, so creating them wastes QPs. Is there any way to avoid creating companion QPs?
> There are two GIN issues that could impact DeepEP V2:
>
>   1. Too many companion QPs are created even if we don't use them. NCCL will resolve this in the next release.
>   2. A load-balancing issue; see "[Issue]: GIN main qps load balance problem when use a lag port with queue affinity policy" (NVIDIA/nccl#2136).

The first will be optimized in NCCL to avoid creating companion QPs; it is tracked in NVIDIA/nccl#2134.
The second one is under review now.
We can use the issues on the NCCL side to track both of these.
Thanks for reporting!

@michaelchen1996

We found that the get_rdma_gbs() function uses ibstat to detect the RDMA NIC bandwidth.
However, in LACP bonding environments, ibstat reports only half of the actual bandwidth, which leads to an incorrect SM count calculation.
To avoid this problem, an environment variable (such as EP_RDMA_GBS) could be used to specify the correct bandwidth manually.
#614
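
A minimal sketch of the override suggested above, in Python. It assumes the variable name `EP_RDMA_GBS` from the comment and a hypothetical helper `resolve_rdma_gbs`; the `detected_gbs` argument stands in for whatever the ibstat-based probe returned, and the numbers in the usage line are purely illustrative. It is not the actual DeepEP implementation.

```python
import os


def resolve_rdma_gbs(detected_gbs, env=None):
    """Return the RDMA bandwidth in GB/s, preferring a manual override.

    `detected_gbs` is the value reported by the automatic probe.
    `EP_RDMA_GBS` is the (hypothetical) override variable proposed above.
    """
    env = os.environ if env is None else env
    raw = env.get("EP_RDMA_GBS", "").strip()
    if not raw:
        # No override set: trust the probe.
        return float(detected_gbs)
    try:
        override = float(raw)
    except ValueError:
        raise ValueError(f"EP_RDMA_GBS must be a number, got {raw!r}") from None
    if override <= 0:
        raise ValueError(f"EP_RDMA_GBS must be positive, got {override}")
    return override


# Illustrative only: the probe sees half of a bonded link, the operator corrects it.
print(resolve_rdma_gbs(25, {"EP_RDMA_GBS": "50"}))
```

Failing loudly on a malformed value is a deliberate choice in this sketch: a silently ignored override would reproduce the wrong SM count that the variable is meant to fix.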

@alpha-baby
Contributor

alpha-baby commented Apr 29, 2026

GB200 performance test

Does DeepEP V2 lack adaptation for the supernode scenario? As a result, communication within the same supernode goes over RDMA.

EP4

SM32

BF16 performance:

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=0, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   2/4 | dispatch: 0 GB/s (SO), 860 GB/s (SU), 293.515 us, 252328160 bytes | copy: 5456 GB/s, 92.491 us
   * EP:   3/4 | dispatch: 0 GB/s (SO), 862 GB/s (SU), 293.365 us, 252933184 bytes | copy: 5472 GB/s, 92.450 us
   * EP:   1/4 | dispatch: 0 GB/s (SO), 862 GB/s (SU), 293.444 us, 252808864 bytes | copy: 5528 GB/s, 91.468 us
   * EP:   0/4 | dispatch: 0 GB/s (SO), 864 GB/s (SU), 292.571 us, 252701120 bytes | copy: 5502 GB/s, 91.857 us
   - EP:   2/4 | expanded dispatch: 0 GB/s (SO), 860 GB/s (SU), 293.253 us, 252328160 bytes | copy: 5231 GB/s, 152.202 us
   - EP:   0/4 | expanded dispatch: 0 GB/s (SO), 863 GB/s (SU), 292.893 us, 252701120 bytes | copy: 5602 GB/s, 142.362 us
   - EP:   1/4 | expanded dispatch: 0 GB/s (SO), 863 GB/s (SU), 293.067 us, 252808864 bytes | copy: 5195 GB/s, 153.544 us
   - EP:   3/4 | expanded dispatch: 0 GB/s (SO), 862 GB/s (SU), 293.274 us, 252933184 bytes | copy: 5375 GB/s, 148.201 us
   # EP:   1/4 | cached dispatch: 0 GB/s (SO), 853 GB/s (SU), 296.533 us, 252808864 bytes | copy: 5106 GB/s, 99.032 us
   # EP:   0/4 | cached dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.565 us, 252701120 bytes | copy: 5660 GB/s, 89.290 us
   # EP:   3/4 | cached dispatch: 0 GB/s (SO), 855 GB/s (SU), 295.929 us, 252933184 bytes | copy: 5141 GB/s, 98.404 us
   # EP:   2/4 | cached dispatch: 0 GB/s (SO), 853 GB/s (SU), 295.692 us, 252328160 bytes | copy: 5494 GB/s, 91.854 us
   @ EP:   0/4 | combine: 0 GB/s (SO), 830 GB/s (SU), 302.116 us, 250749760 bytes | reduce: 1521 GB/s, 176.836 us
   @ EP:   2/4 | combine: 0 GB/s (SO), 826 GB/s (SU), 303.037 us, 250379680 bytes | reduce: 1531 GB/s, 175.667 us
   @ EP:   1/4 | combine: 0 GB/s (SO), 825 GB/s (SU), 304.139 us, 250856672 bytes | reduce: 1524 GB/s, 176.495 us
   @ EP:   3/4 | combine: 0 GB/s (SO), 826 GB/s (SU), 303.761 us, 250980032 bytes | reduce: 1530 GB/s, 175.782 us
   + EP:   0/4 | reduced combine: 0 GB/s (SO), 468 GB/s (SU), 536.128 us, 250749760 bytes | reduce: 1554 GB/s, 173.094 us
   + EP:   3/4 | reduced combine: 0 GB/s (SO), 467 GB/s (SU), 537.569 us, 250980032 bytes | reduce: 1561 GB/s, 172.186 us
   + EP:   2/4 | reduced combine: 0 GB/s (SO), 466 GB/s (SU), 537.387 us, 250379680 bytes | reduce: 1563 GB/s, 171.987 us
   + EP:   1/4 | reduced combine: 0 GB/s (SO), 467 GB/s (SU), 536.832 us, 250856672 bytes | reduce: 1560 GB/s, 172.430 us

FP8 performance:

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=1, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   3/4 | dispatch: 0 GB/s (SO), 812 GB/s (SU), 162.415 us, 131837760 bytes | copy: 5693 GB/s, 46.315 us
   * EP:   2/4 | dispatch: 0 GB/s (SO), 811 GB/s (SU), 162.107 us, 131522400 bytes | copy: 6434 GB/s, 40.882 us
   * EP:   1/4 | dispatch: 0 GB/s (SO), 812 GB/s (SU), 162.337 us, 131772960 bytes | copy: 5820 GB/s, 45.282 us
   * EP:   0/4 | dispatch: 0 GB/s (SO), 816 GB/s (SU), 161.427 us, 131716800 bytes | copy: 6651 GB/s, 39.609 us
   - EP:   3/4 | expanded dispatch: 0 GB/s (SO), 813 GB/s (SU), 162.149 us, 131837760 bytes | copy: 4536 GB/s, 91.665 us
   - EP:   2/4 | expanded dispatch: 0 GB/s (SO), 812 GB/s (SU), 161.914 us, 131522400 bytes | copy: 4530 GB/s, 91.750 us
   - EP:   1/4 | expanded dispatch: 0 GB/s (SO), 812 GB/s (SU), 162.227 us, 131772960 bytes | copy: 4432 GB/s, 93.928 us
   - EP:   0/4 | expanded dispatch: 0 GB/s (SO), 817 GB/s (SU), 161.157 us, 131716800 bytes | copy: 4776 GB/s, 87.148 us
   # EP:   2/4 | cached dispatch: 0 GB/s (SO), 800 GB/s (SU), 164.426 us, 131522400 bytes | copy: 4968 GB/s, 52.950 us
   # EP:   0/4 | cached dispatch: 0 GB/s (SO), 805 GB/s (SU), 163.646 us, 131716800 bytes | copy: 5066 GB/s, 52.001 us
   # EP:   3/4 | cached dispatch: 0 GB/s (SO), 801 GB/s (SU), 164.542 us, 131837760 bytes | copy: 4685 GB/s, 56.283 us
   # EP:   1/4 | cached dispatch: 0 GB/s (SO), 802 GB/s (SU), 164.246 us, 131772960 bytes | copy: 4670 GB/s, 56.433 us
   @ EP:   2/4 | combine: 0 GB/s (SO), 822 GB/s (SU), 304.756 us, 250379680 bytes | reduce: 1530 GB/s, 175.695 us
   @ EP:   3/4 | combine: 0 GB/s (SO), 823 GB/s (SU), 304.825 us, 250980032 bytes | reduce: 1521 GB/s, 176.814 us
   @ EP:   0/4 | combine: 0 GB/s (SO), 827 GB/s (SU), 303.164 us, 250749760 bytes | reduce: 1519 GB/s, 177.088 us
   @ EP:   1/4 | combine: 0 GB/s (SO), 824 GB/s (SU), 304.422 us, 250856672 bytes | reduce: 1522 GB/s, 176.717 us
   + EP:   2/4 | reduced combine: 0 GB/s (SO), 460 GB/s (SU), 544.706 us, 250379680 bytes | reduce: 1564 GB/s, 171.968 us
   + EP:   0/4 | reduced combine: 0 GB/s (SO), 460 GB/s (SU), 544.571 us, 250749760 bytes | reduce: 1548 GB/s, 173.711 us
   + EP:   3/4 | reduced combine: 0 GB/s (SO), 460 GB/s (SU), 545.158 us, 250980032 bytes | reduce: 1560 GB/s, 172.333 us
   + EP:   1/4 | reduced combine: 0 GB/s (SO), 460 GB/s (SU), 545.726 us, 250856672 bytes | reduce: 1554 GB/s, 173.021 us

SM64(default)

BF16 performance:

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=0, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   3/4 | dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.789 us, 252933184 bytes | copy: 5654 GB/s, 89.467 us
   * EP:   0/4 | dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.492 us, 252701120 bytes | copy: 5239 GB/s, 96.465 us
   * EP:   2/4 | dispatch: 0 GB/s (SO), 850 GB/s (SU), 296.938 us, 252328160 bytes | copy: 5004 GB/s, 100.841 us
   * EP:   1/4 | dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.698 us, 252808864 bytes | copy: 5128 GB/s, 98.600 us
   - EP:   0/4 | expanded dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.656 us, 252701120 bytes | copy: 5250 GB/s, 151.907 us
   - EP:   3/4 | expanded dispatch: 0 GB/s (SO), 851 GB/s (SU), 297.261 us, 252933184 bytes | copy: 5485 GB/s, 145.241 us
   - EP:   2/4 | expanded dispatch: 0 GB/s (SO), 848 GB/s (SU), 297.569 us, 252328160 bytes | copy: 5068 GB/s, 157.112 us
   - EP:   1/4 | expanded dispatch: 0 GB/s (SO), 851 GB/s (SU), 297.086 us, 252808864 bytes | copy: 5046 GB/s, 158.050 us
   # EP:   0/4 | cached dispatch: 0 GB/s (SO), 852 GB/s (SU), 296.548 us, 252701120 bytes | copy: 5273 GB/s, 95.852 us
   # EP:   2/4 | cached dispatch: 0 GB/s (SO), 846 GB/s (SU), 298.164 us, 252328160 bytes | copy: 5097 GB/s, 99.014 us
   # EP:   1/4 | cached dispatch: 0 GB/s (SO), 849 GB/s (SU), 297.923 us, 252808864 bytes | copy: 5000 GB/s, 101.121 us
   # EP:   3/4 | cached dispatch: 0 GB/s (SO), 846 GB/s (SU), 298.808 us, 252933184 bytes | copy: 5558 GB/s, 91.019 us
   @ EP:   2/4 | combine: 0 GB/s (SO), 832 GB/s (SU), 301.024 us, 250379680 bytes | reduce: 1531 GB/s, 175.663 us
   @ EP:   0/4 | combine: 0 GB/s (SO), 838 GB/s (SU), 299.209 us, 250749760 bytes | reduce: 1518 GB/s, 177.167 us
   @ EP:   1/4 | combine: 0 GB/s (SO), 834 GB/s (SU), 300.819 us, 250856672 bytes | reduce: 1520 GB/s, 176.920 us
   @ EP:   3/4 | combine: 0 GB/s (SO), 834 GB/s (SU), 300.902 us, 250980032 bytes | reduce: 1526 GB/s, 176.180 us
   + EP:   0/4 | reduced combine: 0 GB/s (SO), 652 GB/s (SU), 384.700 us, 250749760 bytes | reduce: 1553 GB/s, 173.144 us
   + EP:   2/4 | reduced combine: 0 GB/s (SO), 649 GB/s (SU), 385.709 us, 250379680 bytes | reduce: 1566 GB/s, 171.727 us
   + EP:   1/4 | reduced combine: 0 GB/s (SO), 650 GB/s (SU), 386.044 us, 250856672 bytes | reduce: 1556 GB/s, 172.833 us
   + EP:   3/4 | reduced combine: 0 GB/s (SO), 650 GB/s (SU), 386.387 us, 250980032 bytes | reduce: 1562 GB/s, 172.179 us

FP8 performance:

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=1, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   1/4 | dispatch: 0 GB/s (SO), 809 GB/s (SU), 162.851 us, 131772960 bytes | copy: 5712 GB/s, 46.142 us
   * EP:   3/4 | dispatch: 0 GB/s (SO), 810 GB/s (SU), 162.850 us, 131837760 bytes | copy: 6582 GB/s, 40.062 us
   * EP:   0/4 | dispatch: 0 GB/s (SO), 809 GB/s (SU), 162.749 us, 131716800 bytes | copy: 6239 GB/s, 42.226 us
   * EP:   2/4 | dispatch: 0 GB/s (SO), 806 GB/s (SU), 163.113 us, 131522400 bytes | copy: 5912 GB/s, 44.492 us
   - EP:   1/4 | expanded dispatch: 0 GB/s (SO), 806 GB/s (SU), 163.491 us, 131772960 bytes | copy: 4314 GB/s, 96.494 us
   - EP:   0/4 | expanded dispatch: 0 GB/s (SO), 809 GB/s (SU), 162.832 us, 131716800 bytes | copy: 4522 GB/s, 92.051 us
   - EP:   3/4 | expanded dispatch: 0 GB/s (SO), 807 GB/s (SU), 163.431 us, 131837760 bytes | copy: 4468 GB/s, 93.074 us
   - EP:   2/4 | expanded dispatch: 0 GB/s (SO), 804 GB/s (SU), 163.525 us, 131522400 bytes | copy: 4352 GB/s, 95.493 us
   # EP:   1/4 | cached dispatch: 0 GB/s (SO), 796 GB/s (SU), 165.521 us, 131772960 bytes | copy: 4577 GB/s, 57.575 us
   # EP:   3/4 | cached dispatch: 0 GB/s (SO), 795 GB/s (SU), 165.820 us, 131837760 bytes | copy: 5016 GB/s, 52.562 us
   # EP:   0/4 | cached dispatch: 0 GB/s (SO), 802 GB/s (SU), 164.295 us, 131716800 bytes | copy: 4938 GB/s, 53.352 us
   # EP:   2/4 | cached dispatch: 0 GB/s (SO), 795 GB/s (SU), 165.486 us, 131522400 bytes | copy: 4644 GB/s, 56.645 us
   @ EP:   3/4 | combine: 0 GB/s (SO), 834 GB/s (SU), 300.810 us, 250980032 bytes | reduce: 1526 GB/s, 176.147 us
   @ EP:   1/4 | combine: 0 GB/s (SO), 835 GB/s (SU), 300.538 us, 250856672 bytes | reduce: 1521 GB/s, 176.852 us
   @ EP:   0/4 | combine: 0 GB/s (SO), 839 GB/s (SU), 298.836 us, 250749760 bytes | reduce: 1517 GB/s, 177.329 us
   @ EP:   2/4 | combine: 0 GB/s (SO), 831 GB/s (SU), 301.383 us, 250379680 bytes | reduce: 1531 GB/s, 175.668 us
   + EP:   1/4 | reduced combine: 0 GB/s (SO), 642 GB/s (SU), 390.674 us, 250856672 bytes | reduce: 1559 GB/s, 172.527 us
   + EP:   3/4 | reduced combine: 0 GB/s (SO), 643 GB/s (SU), 390.280 us, 250980032 bytes | reduce: 1563 GB/s, 172.016 us
   + EP:   0/4 | reduced combine: 0 GB/s (SO), 646 GB/s (SU), 388.443 us, 250749760 bytes | reduce: 1553 GB/s, 173.197 us
   + EP:   2/4 | reduced combine: 0 GB/s (SO), 642 GB/s (SU), 390.233 us, 250379680 bytes | reduce: 1563 GB/s, 172.024 us

EP8

Config:

Ranks: 1 x 8

Experts: 8/64

Tokens: 8192 (max: 8192), hidden: 4096

#SM: 64, #QPs: 9/17

SM64

command:

MASTER_ADDR=xxxxx MASTER_PORT=29500 WORLD_SIZE=2 RANK=0 LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ CUDA_HOME=/usr/local/cuda NVSHMEM_DEBUG=WARN NCCL_DEBUG=INFO python3 /AntDeepEP_V2/tests/elastic/test_ep.py --num-processes 4 --num-tokens 8192 --hidden 4096 --num-topk 8 --num-experts 64 --allow-hybrid-mode 0 --num-sms 64  > /tmp/debug_bench_mark_ep8_local_experts_8.txt

BF16

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=0, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   0/8 | dispatch: 0 GB/s (SO), 94 GB/s (SU), 3920.000 us, 367904320 bytes | copy: 4781 GB/s, 153.904 us
   * EP:   1/8 | dispatch: 0 GB/s (SO), 94 GB/s (SU), 3922.000 us, 369338144 bytes | copy: 5158 GB/s, 143.206 us
   * EP:   3/8 | dispatch: 0 GB/s (SO), 94 GB/s (SU), 3920.000 us, 368061792 bytes | copy: 4706 GB/s, 156.416 us
   * EP:   2/8 | dispatch: 0 GB/s (SO), 94 GB/s (SU), 3920.000 us, 368302144 bytes | copy: 5408 GB/s, 136.196 us
   - EP:   2/8 | expanded dispatch: 0 GB/s (SO), 94 GB/s (SU), 3920.000 us, 368302144 bytes | copy: 5365 GB/s, 170.496 us
   - EP:   1/8 | expanded dispatch: 0 GB/s (SO), 94 GB/s (SU), 3921.000 us, 369338144 bytes | copy: 5190 GB/s, 175.985 us
   - EP:   0/8 | expanded dispatch: 0 GB/s (SO), 94 GB/s (SU), 3921.000 us, 367904320 bytes | copy: 4950 GB/s, 184.379 us
   - EP:   3/8 | expanded dispatch: 0 GB/s (SO), 94 GB/s (SU), 3921.000 us, 368061792 bytes | copy: 4861 GB/s, 187.454 us
   # EP:   2/8 | cached dispatch: 0 GB/s (SO), 93 GB/s (SU), 3957.000 us, 368302144 bytes | copy: 5087 GB/s, 144.790 us
   # EP:   1/8 | cached dispatch: 0 GB/s (SO), 93 GB/s (SU), 3959.000 us, 369338144 bytes | copy: 4773 GB/s, 154.754 us
   # EP:   0/8 | cached dispatch: 0 GB/s (SO), 93 GB/s (SU), 3958.000 us, 367904320 bytes | copy: 4464 GB/s, 164.836 us
   # EP:   3/8 | cached dispatch: 0 GB/s (SO), 93 GB/s (SU), 3957.000 us, 368061792 bytes | copy: 4325 GB/s, 170.210 us
   @ EP:   1/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3934.000 us, 366486112 bytes | reduce: 2253 GB/s, 251.729 us
   @ EP:   0/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3934.000 us, 365063360 bytes | reduce: 2213 GB/s, 256.929 us
   @ EP:   3/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3933.000 us, 365219616 bytes | reduce: 2235 GB/s, 253.655 us
   @ EP:   2/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3933.000 us, 365458112 bytes | reduce: 2218 GB/s, 255.599 us
   + EP:   0/8 | reduced combine: 0 GB/s (SO), 93 GB/s (SU), 3933.000 us, 365063360 bytes | reduce: 2246 GB/s, 253.104 us
   + EP:   1/8 | reduced combine: 0 GB/s (SO), 93 GB/s (SU), 3934.000 us, 366486112 bytes | reduce: 2287 GB/s, 247.955 us
   + EP:   3/8 | reduced combine: 0 GB/s (SO), 93 GB/s (SU), 3934.000 us, 365219616 bytes | reduce: 2270 GB/s, 249.717 us
   + EP:   2/8 | reduced combine: 0 GB/s (SO), 93 GB/s (SU), 3934.000 us, 365458112 bytes | reduce: 2261 GB/s, 250.659 us

FP8

 > Testing with do_handle_copy=0, expert_alignment=1, use_fp8_dispatch=1, num_bias=2, with_previous_event=1, async_with_compute_stream=1, allocate_on_comm_stream=1 ...
   * EP:   0/8 | dispatch: 0 GB/s (SO), 91 GB/s (SU), 2098.000 us, 191764800 bytes | copy: 3967 GB/s, 96.668 us
   * EP:   3/8 | dispatch: 0 GB/s (SO), 91 GB/s (SU), 2098.000 us, 191846880 bytes | copy: 3957 GB/s, 96.969 us
   * EP:   1/8 | dispatch: 0 GB/s (SO), 92 GB/s (SU), 2099.000 us, 192512160 bytes | copy: 4222 GB/s, 91.197 us
   * EP:   2/8 | dispatch: 0 GB/s (SO), 92 GB/s (SU), 2098.000 us, 191972160 bytes | copy: 4489 GB/s, 85.533 us
   - EP:   0/8 | expanded dispatch: 0 GB/s (SO), 91 GB/s (SU), 2098.000 us, 191764800 bytes | copy: 4056 GB/s, 117.505 us
   - EP:   1/8 | expanded dispatch: 0 GB/s (SO), 92 GB/s (SU), 2099.000 us, 192512160 bytes | copy: 4143 GB/s, 115.136 us
   - EP:   3/8 | expanded dispatch: 0 GB/s (SO), 91 GB/s (SU), 2098.000 us, 191846880 bytes | copy: 4000 GB/s, 118.936 us
   - EP:   2/8 | expanded dispatch: 0 GB/s (SO), 92 GB/s (SU), 2097.000 us, 191972160 bytes | copy: 4424 GB/s, 107.969 us
   # EP:   0/8 | cached dispatch: 0 GB/s (SO), 91 GB/s (SU), 2115.000 us, 191764800 bytes | copy: 3814 GB/s, 100.550 us
   # EP:   1/8 | cached dispatch: 0 GB/s (SO), 91 GB/s (SU), 2116.000 us, 192512160 bytes | copy: 4205 GB/s, 91.571 us
   # EP:   3/8 | cached dispatch: 0 GB/s (SO), 91 GB/s (SU), 2115.000 us, 191846880 bytes | copy: 3696 GB/s, 103.805 us
   # EP:   2/8 | cached dispatch: 0 GB/s (SO), 91 GB/s (SU), 2114.000 us, 191972160 bytes | copy: 4214 GB/s, 91.105 us
   @ EP:   0/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3938.000 us, 365063360 bytes | reduce: 2211 GB/s, 257.147 us
   @ EP:   1/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3939.000 us, 366486112 bytes | reduce: 2253 GB/s, 251.724 us
   @ EP:   3/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3938.000 us, 365219616 bytes | reduce: 2236 GB/s, 253.542 us
   @ EP:   2/8 | combine: 0 GB/s (SO), 93 GB/s (SU), 3938.000 us, 365458112 bytes | reduce: 2219 GB/s, 255.429 us
   + EP:   0/8 | reduced combine: 0 GB/s (SO), 91 GB/s (SU), 3992.000 us, 365063360 bytes | reduce: 2249 GB/s, 252.874 us
   + EP:   1/8 | reduced combine: 0 GB/s (SO), 92 GB/s (SU), 3993.000 us, 366486112 bytes | reduce: 2288 GB/s, 247.921 us
   + EP:   3/8 | reduced combine: 0 GB/s (SO), 91 GB/s (SU), 3993.000 us, 365219616 bytes | reduce: 2270 GB/s, 249.732 us
   + EP:   2/8 | reduced combine: 0 GB/s (SO), 92 GB/s (SU), 3992.000 us, 365458112 bytes | reduce: 2254 GB/s, 251.504 us
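
To compare the runs above more easily, the per-rank lines can be parsed and the reported scale-up (SU) bandwidth cross-checked against bytes divided by time. The following is a rough Python sketch; the regular expression and the helper names `parse_line` and `implied_gbs` are assumptions based only on the log format shown here, and the bytes-over-time relationship is inferred from the numbers rather than from the test source.

```python
import re

# Matches the per-rank lines printed by the benchmark above, e.g.
# "* EP:   2/4 | dispatch: 0 GB/s (SO), 860 GB/s (SU), 293.515 us, 252328160 bytes | ..."
LINE_RE = re.compile(
    r"EP:\s*(?P<rank>\d+)/(?P<world>\d+) \| (?P<op>[a-z ]+): "
    r"(?P<so>\d+) GB/s \(SO\), (?P<su>\d+) GB/s \(SU\), "
    r"(?P<us>[\d.]+) us, (?P<nbytes>\d+) bytes"
)


def parse_line(line):
    """Return the fields of one benchmark line, or None if it does not match."""
    m = LINE_RE.search(line)
    if m is None:
        return None
    return {
        "rank": int(m["rank"]),
        "world": int(m["world"]),
        "op": m["op"],
        "so_gbs": int(m["so"]),
        "su_gbs": int(m["su"]),
        "us": float(m["us"]),
        "nbytes": int(m["nbytes"]),
    }


def implied_gbs(rec):
    """Bandwidth implied by bytes / time: bytes per microsecond equals MB/s."""
    return rec["nbytes"] / rec["us"] / 1e3


# One EP4 line and one EP8 line copied from the logs above.
SAMPLES = [
    "   * EP:   2/4 | dispatch: 0 GB/s (SO), 860 GB/s (SU), 293.515 us, 252328160 bytes | copy: 5456 GB/s, 92.491 us",
    "   * EP:   0/8 | dispatch: 0 GB/s (SO), 94 GB/s (SU), 3920.000 us, 367904320 bytes | copy: 4781 GB/s, 153.904 us",
]

for line in SAMPLES:
    rec = parse_line(line)
    print(rec["world"], rec["op"], rec["su_gbs"], round(implied_gbs(rec), 1))
```

On these two samples the implied value agrees with the printed SU figure to within rounding, which makes the drop from the EP4 run to the EP8 run easy to quantify per operation.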

I added some logging to track this down.

A temporary print was added right after csrc/kernels/backend/nccl.cu:103:

printf("DeepEP dev_comm: lsaSize=%d lsaRank=%d world=%d rank=%d\n",
dev_comm.lsaSize, dev_comm.lsaRank, num_ranks, rank_idx);

Could there be a problem with how the NCCL dev_comm is created?
 Debug log:

 DeepEP dev_comm: lsaSize=4 lsaRank=2 world=8 rank=2
 DeepEP dev_comm: lsaSize=4 lsaRank=3 world=8 rank=3
 DeepEP dev_comm: lsaSize=4 lsaRank=1 world=8 rank=5
 DeepEP dev_comm: lsaSize=4 lsaRank=0 world=8 rank=4
 DeepEP dev_comm: lsaSize=4 lsaRank=3 world=8 rank=7
 DeepEP dev_comm: lsaSize=4 lsaRank=2 world=8 rank=6


Commands used:

  "MASTER_ADDR=xxxx MASTER_PORT=29500 WORLD_SIZE=2 RANK=0 LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ CUDA_HOME=/usr/local/cuda NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /AntDeepEP_V2/tests/elastic/test_ep.py --num-processes 4 --num-tokens 8192 --hidden 4096 --num-topk 8 --num-experts 64 --allow-hybrid-mode 0 --num-sms 64 --skip-check --test-first-only  > /tmp/debug_bench_mark_ep8_local_experts_8.txt 2>&1 &"

  "MASTER_ADDR=xxxx MASTER_PORT=29500 WORLD_SIZE=2 RANK=1 LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/nvidia/nccl/lib/ CUDA_HOME=/usr/local/cuda NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /AntDeepEP_V2/tests/elastic/test_ep.py --num-processes 4 --num-tokens 8192 --hidden 4096 --num-topk 8 --num-experts 64 --allow-hybrid-mode 0 --num-sms 64 --skip-check --test-first-only  > /tmp/debug_bench_mark_ep8_local_experts_8.txt 2>&1 &"

NIC has traffic:

image

@alpha-baby
Contributor

@xiaofanl-nvidia

This looks like an issue with NCCL dev_comm. Could you help take a look? Thank you very much!

@alpha-baby
Contributor

> @xiaofanl-nvidia
>
> This looks like an issue with NCCL dev_comm. Could you help take a look? Thank you very much!

Please ignore this issue. It happens because my two nodes are not in the same NVLink domain.
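
The debug output posted earlier already hints at this: every rank reports lsaSize=4 while world=8. Below is a small Python sketch for reading such lines. It assumes that lsaSize is the number of ranks reachable by direct load/store and that ranks inside one domain are numbered contiguously, so `rank - lsaRank` identifies the domain; the helper names are hypothetical.

```python
import re

DEBUG_RE = re.compile(r"lsaSize=(\d+) lsaRank=(\d+) world=(\d+) rank=(\d+)")


def lsa_domains(lines):
    """Group ranks by inferred domain base (rank - lsaRank); an assumption, see above."""
    domains = {}
    for line in lines:
        m = DEBUG_RE.search(line)
        if m is None:
            continue
        _lsa_size, lsa_rank, _world, rank = map(int, m.groups())
        domains.setdefault(rank - lsa_rank, []).append(rank)
    return {base: sorted(ranks) for base, ranks in domains.items()}


def crosses_domain(lines):
    """True if any rank sees fewer load/store peers than the world size."""
    for line in lines:
        m = DEBUG_RE.search(line)
        if m is not None:
            lsa_size, _lsa_rank, world, _rank = map(int, m.groups())
            if lsa_size < world:
                return True
    return False


# The six lines from the debug log in this thread.
LOG = [
    "DeepEP dev_comm: lsaSize=4 lsaRank=2 world=8 rank=2",
    "DeepEP dev_comm: lsaSize=4 lsaRank=3 world=8 rank=3",
    "DeepEP dev_comm: lsaSize=4 lsaRank=1 world=8 rank=5",
    "DeepEP dev_comm: lsaSize=4 lsaRank=0 world=8 rank=4",
    "DeepEP dev_comm: lsaSize=4 lsaRank=3 world=8 rank=7",
    "DeepEP dev_comm: lsaSize=4 lsaRank=2 world=8 rank=6",
]

print(lsa_domains(LOG))
print(crosses_domain(LOG))
```

Under those assumptions the log splits into two groups of ranks, one per node, which is consistent with traffic between the groups going over the NIC rather than NVLink.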

dmvevents pushed a commit to dmvevents/RL that referenced this pull request May 6, 2026
Bumps the deep_ep git pin in pyproject.toml from bfded348
(2025-10-29, pre-V2) to b306af0 (2026-04-29), which is the
merge commit of DeepEP PR deepseek-ai/DeepEP#605 "Introducing EPv2".

Why
---
The current pin predates the DeepEP V2 API (ElasticBuffer,
PP/CP/Engram support). Consumers of NeMo-RL's Megatron backend
that follow NVIDIA/Megatron-LM#4632 ("Shape Y" Megatron V2
adoption) cannot resolve deep_ep.ElasticBuffer with the
current pin; the virtualenv still installs the pre-V2 tree.

This change bumps only the pin. It does not by itself change
any NeMo-RL code path. Paired with Megatron-LM#4632, it
enables the end-to-end V2 path that is already running on
AWS p5en.48xlarge 2x H200 in the reproduction repo below.

Upstream references
-------------------
* deepseek-ai/DeepEP#605 (V2 merge 2026-04-29)
* NVIDIA/Megatron-LM#4632 (Megatron-side V2 adoption)

Reproduction
------------
End-to-end reproduction (Dockerfile + K8s manifests + smoke
bench) is public at:
  https://github.com/antonai-work/nemo-rl-deepep-v2-efa

Related NeMo-RL PR (separate concern, same fleet):
  NVIDIA-NeMo#2410 (Dockerfile LD_LIBRARY_PATH for EFA
  OFI discovery)

Signed-off-by: Anton Alexander <antonai@users.noreply.github.com>
zhijiehou pushed a commit to zhijiehou/rtp-llm that referenced this pull request May 14, 2026
…itch

Add `RTP_LLM_DEEPEP_BACKEND={legacy,elastic}` runtime switch so rtp-llm can
keep the v1-compatible `deep_ep::legacy::Buffer` path (default) and
opt into the v2 unified `ElasticBuffer` (PR deepseek-ai/DeepEP#605) without
forking the engine. Unknown values silently fall back to legacy so a typo
never flips a production deploy to v2.

- deepep_wrapper.py: add DeepEPBackend / deepep_backend() / is_elastic_buffer
  helpers, branch _init_deepep_buffer at the entry, and add a new
  _init_elastic_buffer that uses ElasticBuffer.get_buffer_size_hint to size
  the unified buffer. The three legacy init paths
  (_init_normal_buffer / _init_low_latency_buffer / _init_low_latency_m2n_buffer)
  are kept verbatim so v1 behavior is preserved.
- deepep_normal_router.py: branch prepare()/finalize() at the entry on
  is_elastic_buffer(); the elastic path skips get_dispatch_layout (unified
  API computes layout internally), uses the v2 dispatch 5-tuple and reads
  num_recv_tokens_per_expert_list from EPHandle. The legacy 6-tuple call
  site is untouched.
- deepep_low_latency_router.py: collapse the v1 two-API setup
  (low_latency_dispatch / low_latency_combine) onto ElasticBuffer's unified
  dispatch / combine when the elastic backend is active. _is_elastic is
  cached at construction time so prepare()/finalize() only branch once.
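
A rough Python sketch of how the switch described in this commit message might look. The names `DeepEPBackend`, `deepep_backend`, `is_elastic_buffer` and `RTP_LLM_DEEPEP_BACKEND` come from the message; the bodies, the `env` parameter and the exact-match parsing are assumptions, since the actual rtp-llm code is not shown in this thread.

```python
import os
from enum import Enum


class DeepEPBackend(Enum):
    LEGACY = "legacy"    # v1-compatible deep_ep::legacy::Buffer path (default)
    ELASTIC = "elastic"  # v2 unified ElasticBuffer path


def deepep_backend(env=None):
    """Read RTP_LLM_DEEPEP_BACKEND; anything unrecognized falls back to legacy."""
    env = os.environ if env is None else env
    raw = env.get("RTP_LLM_DEEPEP_BACKEND", "").strip()
    try:
        return DeepEPBackend(raw)
    except ValueError:
        # Unset, empty, or mistyped values never select the v2 path.
        return DeepEPBackend.LEGACY


def is_elastic_buffer(env=None):
    """Convenience check used to branch at the entry of init/prepare/finalize."""
    return deepep_backend(env) is DeepEPBackend.ELASTIC


print(deepep_backend({"RTP_LLM_DEEPEP_BACKEND": "elastic"}))
print(deepep_backend({"RTP_LLM_DEEPEP_BACKEND": "elastc"}))
```

Matching the value exactly, without case folding, keeps the fallback conservative in this sketch: only the literal string `elastic` opts in.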
zhijiehou pushed a commit to zhijiehou/rtp-llm that referenced this pull request May 15, 2026
7 participants