
@itayalroy commented Oct 22, 2025

Add Elasticity Support to DeepEP via NIXL Integration

Summary

This PR adds elastic scaling capabilities to DeepEP, enabling processes (ranks) to be added and removed on demand at runtime without affecting existing connections. This is achieved by integrating the NVIDIA Inference Xfer Library (NIXL), a high-performance communication library that uses RDMA and NVLink transports and supports dynamic rank management.

Note: This PR currently replaces NVSHMEM calls with NIXL calls. We would like to discuss the best way to support multiple communication libraries going forward.
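For the multi-library question, one common pattern is a thin factory that selects the backend at construction time, keeping the NVSHMEM path as the default. Purely illustrative, not part of this PR: the DEEP_EP_BACKEND variable and make_buffer helper below are hypothetical, and in practice the two constructors take different arguments.

import os
import deep_ep

def make_buffer(rank, group=None, **kwargs):
    # Hypothetical selector: choose the communication backend at runtime.
    backend = os.environ.get("DEEP_EP_BACKEND", "nvshmem")
    if backend == "nixl":
        return deep_ep.nixl_buffer(rank, **kwargs)  # elastic, connection-based (this PR)
    return deep_ep.Buffer(group, **kwargs)          # existing NVSHMEM-backed buffer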

Included in this PR:

✅ Integrated NIXL with DeepEP (low-latency & internode modes), tested with DeepEP's original benchmark
✅ Tested with TRT-LLM, vLLM, and SGLang
✅ Introduced new buffer APIs for elastic addition/removal of ranks
✅ Extended DeepEP's benchmark to add/remove ranks at runtime

Next Steps

⬜ Support multiple calls to update_memory_buffers() for elastic allocation/deallocation of GPU memory
⬜ Integrate DeepEP's failure detection with remove_ranks API
⬜ Support elasticity in intranode kernels

New Buffer Initialization Pattern:

# New API: dynamic initialization
buffer = deep_ep.nixl_buffer(rank, low_latency_mode=True, explicitly_destroy=True)
# Size the buffers once for the maximum expected number of ranks and experts
buffer.update_memory_buffers(num_ranks, num_experts_per_rank, nvl_bytes, rdma_bytes)
# Establish NIXL connections to the initial set of peers
buffer.connect_ranks(initial_ranks)

# Dispatch & combine calls
buffer.dispatch(...)
buffer.combine(...)

# Later: add new ranks dynamically, without touching existing connections
buffer.connect_ranks(new_ranks)

# Dispatch & combine calls now span the enlarged set of ranks
buffer.dispatch(...)
buffer.combine(...)

# Remove ranks when scaling down
buffer.remove_ranks(ranks_to_remove)

New Buffer APIs:

  • nixl_buffer(rank_id, low_latency_mode, low_latency_nvlink_backend, explicitly_destroy): Initialize the NIXL communication buffer
  • update_memory_buffers(num_ranks, num_experts_per_rank, num_nvl_bytes, num_rdma_bytes): Prepare buffers for up to num_ranks ranks and num_experts_per_rank experts per rank (a sizing sketch follows this list)
  • connect_ranks(remote_ranks): Establish NIXL connections to new peers (can be called multiple times)
  • remove_ranks(remote_ranks): Clean up connections to departing peers
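
Since update_memory_buffers() pre-allocates for the maximum expected world size, a natural pattern is to size the buffers once for the scaling ceiling and afterwards only grow or shrink the connection set. A minimal sketch, assuming DeepEP's existing Buffer.get_low_latency_rdma_size_hint() helper still applies to the NIXL buffer; the ceiling and workload parameters below are illustrative, not taken from this PR.

import deep_ep

rank = 0                            # this process's rank id (illustrative)
max_ranks = 16                      # scaling ceiling to pre-allocate for
num_experts = 256
hidden = 7168
max_dispatch_tokens_per_rank = 128

# Assumption: the existing low-latency RDMA size hint is reused as-is.
rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
    max_dispatch_tokens_per_rank, hidden, max_ranks, num_experts)

buffer = deep_ep.nixl_buffer(rank, low_latency_mode=True, explicitly_destroy=True)
# Prepare memory once for up to max_ranks ranks
# (num_nvl_bytes = 0 in pure low-latency mode, as in DeepEP's low-latency example) ...
buffer.update_memory_buffers(max_ranks, num_experts // max_ranks, 0, rdma_bytes)
# ... then connect only the ranks that exist right now.
buffer.connect_ranks(list(range(4)))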

Testing

New elastic test suite in tests/elastic/:

  • A plan file (representing an orchestrator) defines the scaling phases
  • The test validates correctness and measures bandwidth between scaling phases

Example Plan (expansion_contraction.json):

[
  [0, 1, 2, 3],
  [0, 1, 2, 3, 4, 5, 6, 7],
  [0, 1, 2, 3, 4, 5]
]

This plan defines three phases (a driver sketch follows the list):

  • Phase 0: Initial state with ranks 0-3
  • Phase 1: Ranks 4-7 are added dynamically (launched independently of the initial ranks)
  • Phase 2: Ranks 6-7 are removed dynamically
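
A driver for such a plan can be a simple loop over consecutive phases that applies the set difference between them through the APIs above. A minimal sketch, not the actual test code: plan loading and inter-phase synchronization are elided, and the run_plan name is illustrative.

import json

def run_plan(buffer, my_rank, plan_path):
    # Walk the scaling phases, connecting new peers and dropping departed ones.
    with open(plan_path) as f:
        phases = [set(phase) for phase in json.load(f)]

    current = set()
    for phase in phases:
        if my_rank not in phase:
            break  # this rank was scaled away in this phase
        added = sorted(phase - current)
        removed = sorted(current - phase)
        if removed:
            buffer.remove_ranks(removed)
        if added:
            buffer.connect_ranks(added)
        current = phase
        # ... run dispatch/combine traffic and measure bandwidth here ...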

Performance Testing

All benchmarks were run on two nodes of the NVIDIA EOS cluster (8× H100 GPUs and 8× CX7 NICs per node, InfiniBand interconnect), for a total of 16 ranks.

Low-Latency Kernels (128 tokens, hidden = 7168, top-k = 8, 16 ranks, 32 experts):

| Implementation | Avg Bandwidth (GB/s) | Avg Latency (µs) | Min (µs) | Max (µs) |
| --- | --- | --- | --- | --- |
| DeepEP main | 59.17 | 372.7 | 347.5 | 387.5 |
| This PR (NIXL) | 59.88 | 368.2 | 344.6 | 383.7 |

Internode Kernels (4096 tokens, hidden = 7168, top-k = 8, 16 ranks, 256 experts):

| Operation | Precision | Implementation | Transmit (µs) | RDMA BW (GB/s) | NVLink BW (GB/s) |
| --- | --- | --- | --- | --- | --- |
| Dispatch | FP8 | DeepEP main | 759.15 | 79.51 | 259.53 |
| Dispatch | FP8 | This PR (NIXL) | 773.84 | 78.00 | 254.61 |
| Dispatch | BF16 | DeepEP main | 1363.00 | 85.89 | 280.35 |
| Dispatch | BF16 | This PR (NIXL) | 1359.00 | 86.14 | 281.17 |
| Combine | BF16 | DeepEP main | 1767.00 | 66.25 | 216.25 |
| Combine | BF16 | This PR (NIXL) | 1856.00 | 63.08 | 205.88 |

Example Launch

Refer to NIXL_README.md for detailed instructions on how to run the elastic test suite.

Co-authored-by: Roey Azran <roeya@nvidia.com>
Co-authored-by: Micha Dery <mdery@nvidia.com>
Co-authored-by: Asaf Schwartz <aschwartz@nvidia.com>

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
@sphish (Collaborator) commented Oct 24, 2025

Amazing work, and thank you so much for your PR!

However, I currently believe that the symmetric-memory communication model provided by NVSHMEM is particularly well suited to machine-learning communication scenarios. Compared with a connection-based communication model, it is simpler and easier to understand, and even developers without a networking background can use it effectively.

To keep the project easy to maintain, I'm not considering introducing NIXL at this time.
