
Conversation

@Gaojiaqi
Contributor

TL;DR: This feature branch replaces the standard “write-then-atomic ACK” path with an Eager RDMA layout that bundles data and per-token signals into a single IBV_WR_RDMA_WRITE. It removes the extra RTT latency introduced by atomics on multi-path NICs and enables fine-grained overlap at the receiver. In our tests (Hopper + BlueField-3, multi-path enabled), Eager RDMA reduces latency by up to 20%.

Background

DeepEP’s current implementation follows the InfiniBand Architecture Specification and uses a standard acknowledgement sequence on the same QP:

  • Sender issues a burst of IBV_WR_RDMA_WRITE (data “tokens”).
  • Sender follows with an IBV_WR_ATOMIC_FETCH_AND_ADD as the confirmation.
  • Receiver waits for the atomic to detect completion.

This guarantees that all tokens posted before the atomic are visible when the atomic is observed.
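
The following is a minimal host-side sketch of that baseline "write-then-atomic ACK" sequence, not DeepEP's actual code. It assumes a connected RC QP, a pre-registered local MR, remote addresses/rkeys exchanged out of band, and omits completion and error handling; names such as post_tokens_then_atomic are illustrative.

```cuda
#include <cstdint>
#include <infiniband/verbs.h>

// Baseline path: a burst of RDMA writes followed by an atomic FAA as the ACK.
void post_tokens_then_atomic(struct ibv_qp *qp, struct ibv_mr *mr,
                             void *local_buf, uint64_t remote_addr, uint32_t rkey,
                             size_t token_bytes, int num_tokens,
                             uint64_t remote_counter_addr, uint32_t counter_rkey) {
    struct ibv_send_wr *bad_wr = nullptr;

    // 1) Burst of IBV_WR_RDMA_WRITE carrying the data tokens.
    for (int i = 0; i < num_tokens; ++i) {
        struct ibv_sge sge = {};
        sge.addr   = (uintptr_t)local_buf + (size_t)i * token_bytes;
        sge.length = (uint32_t)token_bytes;
        sge.lkey   = mr->lkey;

        struct ibv_send_wr wr = {};
        wr.opcode              = IBV_WR_RDMA_WRITE;
        wr.sg_list             = &sge;
        wr.num_sge             = 1;
        wr.wr.rdma.remote_addr = remote_addr + (uint64_t)i * token_bytes;
        wr.wr.rdma.rkey        = rkey;
        ibv_post_send(qp, &wr, &bad_wr);
    }

    // 2) Atomic fetch-and-add used as the completion signal. On multi-path
    //    fabrics the stack may fence this behind all in-flight writes,
    //    adding roughly one RTT of latency per atomic.
    struct ibv_sge asge = {};
    asge.addr   = (uintptr_t)local_buf;  // 8-byte landing buffer for the old value
    asge.length = 8;
    asge.lkey   = mr->lkey;

    struct ibv_send_wr awr = {};
    awr.opcode                = IBV_WR_ATOMIC_FETCH_AND_ADD;
    awr.sg_list               = &asge;
    awr.num_sge               = 1;
    awr.send_flags            = IBV_SEND_SIGNALED;
    awr.wr.atomic.remote_addr = remote_counter_addr;
    awr.wr.atomic.rkey        = counter_rkey;
    awr.wr.atomic.compare_add = 1;       // bump the remote completion counter
    ibv_post_send(qp, &awr, &bad_wr);
}
```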

On modern multi-path capable adapters (e.g., NVIDIA ConnectX-7, BlueField-3) and multi-path fabrics (InfiniBand Adaptive Routing, multi-path Ethernet), the network may reorder the atomic and write packets. To preserve the ordering semantics, the stack refrains from issuing the atomic until all in-flight writes on that QP have completed (i.e., it fences the atomic), effectively injecting one RTT of latency for each atomic that follows a burst of writes.

Eager RDMA removes atomics and fences from the critical path by interleaving data and a small “signal” field inside every write WQE. The sender emits a single IBV_WR_RDMA_WRITE that carries both payload and signals; the receiver polls those signals directly in device memory to determine readiness. This further enables per-token acknowledgement: the receiver can process each token as soon as it is ready, giving fine-grained overlap between intra-node D2D copies and inter-node sending/receiving.

Requirements

  1. GPU: NVIDIA Hopper (SM90) or newer.
  2. MTU: 4096 bytes.

Ordering Guarantees

On Hopper or newer GPUs, DMA writes arriving from the RNIC into GPU memory become visible to GPU kernels in the order in which they update memory, provided all of the following hold:

  1. Architecture: GPU is Hopper (SM90) or newer.
  2. CUDA pointer attribute: CU_POINTER_ATTRIBUTE_SYNC_MEMOPS is set via cuPointerSetAttribute (NVSHMEM sets this by default).
  3. IB verbs MR: the MR is registered without IBV_ACCESS_RELAXED_ORDERING (with NVSHMEM, disable NVSHMEM_IB_ENABLE_RELAXED_ORDERING).

Under these conditions, when a full-MTU RDMA write completes, the GPU’s memory view is ordered—i.e., once the last 16-byte signal is updated, the preceding 4080 bytes of data in that MTU tile are also visible.
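
As a sketch of conditions 2 and 3 above, the snippet below marks a device allocation with CU_POINTER_ATTRIBUTE_SYNC_MEMOPS and registers it without relaxed ordering. It assumes GPUDirect RDMA is available so device memory can be registered with the RNIC, and omits error handling; NVSHMEM performs the equivalent steps internally, so this is only illustrative.

```cuda
#include <cuda.h>
#include <infiniband/verbs.h>

struct ibv_mr *register_ordered_gpu_buffer(struct ibv_pd *pd,
                                           CUdeviceptr gpu_buf, size_t bytes) {
    // Condition 2: request synchronized memory operations on this allocation
    // (NVSHMEM sets this attribute on its symmetric heap by default).
    int sync_memops = 1;
    cuPointerSetAttribute(&sync_memops, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, gpu_buf);

    // Condition 3: register the MR *without* IBV_ACCESS_RELAXED_ORDERING
    // (with NVSHMEM, export NVSHMEM_IB_ENABLE_RELAXED_ORDERING=0).
    int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
    return ibv_reg_mr(pd, (void *)gpu_buf, bytes, access);
}
```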

Layout & Data Path

Sender layout

Each 4096-byte MTU tile holds a data segment followed by a 16-byte signal, so a token is laid out on the wire as:

  [ Data 4080B | Signal 16B ] [ Data 4080B | Signal 16B ] … [ Data ≤4080B | Signal 16B ]
  • Data: up to 4080 bytes.
  • Signal: 16 bytes (per-token signal field).
  • The final tile may have a short data segment (≤4080B) followed by its 16B signal.

The sender maps its normal memory view into the Eager RDMA view and issues IBV_WR_RDMA_WRITE WQEs with this layout; the original atomics are replaced by RDMA writes.
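
A hypothetical address-mapping helper for this layout is shown below: every 4080 data bytes are followed by a 16-byte signal, so each 4096-byte MTU tile carries 4080 bytes of payload. Constants and function names are illustrative, not DeepEP's actual API.

```cuda
#include <cstddef>

constexpr size_t kMtuBytes    = 4096;
constexpr size_t kSignalBytes = 16;
constexpr size_t kDataBytes   = kMtuBytes - kSignalBytes;  // 4080

// Map a logical (contiguous) byte offset to its position in the eager view.
__host__ __device__ inline size_t eager_offset(size_t logical_offset) {
    size_t tile   = logical_offset / kDataBytes;  // which MTU tile
    size_t within = logical_offset % kDataBytes;  // position inside its data segment
    return tile * kMtuBytes + within;
}

// Total wire size of a token of n payload bytes: one 16B signal per
// (possibly short) tile.
__host__ __device__ inline size_t eager_size(size_t n) {
    size_t tiles = (n + kDataBytes - 1) / kDataBytes;
    return n + tiles * kSignalBytes;
}
```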

Receiver progress

  1. The receiver polls the embedded 16B signal of each token.
  2. When all of a token’s signals match the expected value, that token’s data is guaranteed to be ready.
  3. The receiver can immediately consume/copy that token and continue polling subsequent tokens—achieving per-token overlap and reducing overall collective completion time.
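
A simplified device-side sketch of this polling loop follows. It assumes signal_ptr points at the 16-byte signal of a token's last tile, that the sender wrote expected into both 8-byte halves of that signal, and that a real implementation would check every signal a token spans; names are illustrative.

```cuda
__device__ inline void wait_token_ready(const volatile unsigned long long *signal_ptr,
                                        unsigned long long expected) {
    // Spin until the 16B signal matches. Because each full-MTU write lands in
    // order under the conditions above, a matching signal implies the 4080
    // data bytes ahead of it in the tile are already visible.
    while (signal_ptr[0] != expected || signal_ptr[1] != expected) {
        // busy-wait; __nanosleep() could be inserted here to reduce pressure
    }
    // Memory fence so later reads of the token data are not reordered
    // before the polling loads.
    __threadfence();
}
```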

Programming Model & APIs

Eager RDMA provides two ways to integrate with existing kernels:

  • Wrapped load/store APIs (global memory):
    If your code uses explicit loads/stores, call the Eager RDMA wrappers to access the interleaved view safely.

  • In-place TMA view transforms (shared memory):
    If your code uses TMA (Tensor Memory Accelerator), use the provided in-place transform to construct/deconstruct the Eager RDMA layout when staging through shared memory.

The goal is to keep your kernel logic largely unchanged while ensuring that reads/writes respect the interleaved data+signal format and ordered visibility rules above.
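
Below is a hypothetical shape for the wrapped load/store path, reusing the eager_offset helper sketched earlier; the actual DeepEP wrapper names and signatures may differ. The idea is that a kernel keeps thinking in contiguous token bytes while the wrapper remaps each access into the interleaved view (the caller must keep accesses from straddling a signal slot).

```cuda
template <typename T>
__device__ inline void eager_store(void *eager_base, size_t logical_offset,
                                   const T &value) {
    // Remap the logical offset into the data+signal interleaved layout.
    char *dst = static_cast<char *>(eager_base) + eager_offset(logical_offset);
    *reinterpret_cast<T *>(dst) = value;
}

template <typename T>
__device__ inline T eager_load(const void *eager_base, size_t logical_offset) {
    const char *src =
        static_cast<const char *>(eager_base) + eager_offset(logical_offset);
    return *reinterpret_cast<const T *>(src);
}
```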

Performance

  • Setup: Hopper GPU + BlueField-3 RNIC.
  • Result: Eager RDMA achieves up to 20% latency reduction vs. the original DeepEP path.
[Charts: EP16 eager, EP32 eager]

@Chaser-wind Chaser-wind merged commit f4e0dd4 into deepseek-ai:eager-rdma Sep 30, 2025