52 changes: 30 additions & 22 deletions csrc/deep_ep.cpp
@@ -940,6 +940,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment,
int num_worst_tokens,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
@@ -1112,6 +1113,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
num_experts,
is_token_in_rank.data_ptr<bool>(),
num_tokens,
num_worst_tokens,
num_channels,
hidden_int4,
num_scales,
@@ -1133,30 +1135,35 @@
low_latency_mode);

// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);

// Read per-expert count
bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;

if (ready)
break;

// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
NUM_CPU_TIMEOUT_SECS) {
printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
for (int i = 0; i < num_local_experts; ++i)
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
if (num_worst_tokens > 0) {
num_recv_tokens = num_worst_tokens;
num_rdma_recv_tokens = num_worst_tokens;
} else {
auto start_time = std::chrono::high_resolution_clock::now();
while (true) {
// Read total count
num_recv_tokens = static_cast<int>(*moe_recv_counter);
num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);

// Read per-expert count
bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
for (int i = 0; i < num_local_experts and ready; ++i)
ready &= moe_recv_expert_counter[i] >= 0;

if (ready)
break;

// Timeout check
if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
NUM_CPU_TIMEOUT_SECS) {
printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
for (int i = 0; i < num_local_experts; ++i)
printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
}
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}
num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
}

// Allocate new tensors
@@ -1213,6 +1220,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
recv_gbl_rank_prefix_sum.data_ptr<int>(),
is_token_in_rank.data_ptr<bool>(),
num_tokens,
num_worst_tokens,
hidden_int4,
num_scales,
num_topk,
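The host-side change above is the key behavioural difference: with `num_worst_tokens > 0`, `internode_dispatch` no longer busy-waits on the pinned counters that `notify_dispatch` fills in, and simply sizes the receive buffers to the worst case (the per-expert count list is also skipped). For readers unfamiliar with that handshake, here is a minimal Python sketch of the polling protocol being bypassed; the thread, sleep duration, and counter names are illustrative stand-ins, not the actual implementation.

```python
import threading, time

NOT_READY = -1                           # counters are reset to -1 before the kernel runs
NUM_CPU_TIMEOUT_SECS = 100               # illustrative timeout, mirroring the C++ constant
moe_recv_counter = [NOT_READY]           # stands in for the pinned, device-visible int

def fake_notify_dispatch(total_recv_tokens):
    time.sleep(0.01)                     # pretend the GPU kernel is still counting
    moe_recv_counter[0] = total_recv_tokens

threading.Thread(target=fake_notify_dispatch, args=(4096,)).start()

start_time = time.time()
while moe_recv_counter[0] == NOT_READY:  # host busy-waits, as in internode_dispatch
    if time.time() - start_time > NUM_CPU_TIMEOUT_SECS:
        raise RuntimeError('DeepEP error: timeout (dispatch CPU)')
num_recv_tokens = moe_recv_counter[0]    # with num_worst_tokens > 0 this wait is skipped
```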
1 change: 1 addition & 0 deletions csrc/deep_ep.hpp
@@ -229,6 +229,7 @@ struct Buffer {
const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
int expert_alignment,
int num_worst_tokens,
const Config& config,
std::optional<EventHandle>& previous_event,
bool async,
2 changes: 2 additions & 0 deletions csrc/kernels/api.cuh
@@ -154,6 +154,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
int num_experts,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int num_channels,
int hidden_int4,
int num_scales,
@@ -193,6 +194,7 @@ void dispatch(void* recv_x,
const int* recv_gbl_rank_prefix_sum,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int hidden_int4,
int num_scales,
int num_topk,
46 changes: 37 additions & 9 deletions csrc/kernels/internode.cu
@@ -100,6 +100,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
int num_experts,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int num_channels,
int expert_alignment,
const int rdma_clean_offset,
@@ -236,9 +237,11 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
recv_rdma_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
;
*moe_recv_rdma_counter_mapped = sum;
if (num_worst_tokens == 0) {
while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
;
*moe_recv_rdma_counter_mapped = sum;
}
}

// Send numbers of tokens per rank/expert to NVL ranks
@@ -263,19 +266,23 @@
sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
recv_gbl_rank_prefix_sum[i] = sum;
}
while (ld_volatile_global(moe_recv_counter_mapped) != -1)
;
*moe_recv_counter_mapped = sum;
if (num_worst_tokens == 0) {
while (ld_volatile_global(moe_recv_counter_mapped) != -1)
;
*moe_recv_counter_mapped = sum;
}
}
if (thread_id < num_nvl_experts) {
int sum = 0;
#pragma unroll
for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
;
moe_recv_expert_counter_mapped[thread_id] = sum;
if (num_worst_tokens == 0) {
while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
;
moe_recv_expert_counter_mapped[thread_id] = sum;
}
}

// Finally barrier
@@ -346,6 +353,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
int num_experts,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int num_channels,
int hidden_int4,
int num_scales,
@@ -380,6 +388,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
num_experts, \
is_token_in_rank, \
num_tokens, \
num_worst_tokens, \
num_channels, \
expert_alignment, \
rdma_clean_meta.first, \
@@ -455,6 +464,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
const int* recv_gbl_rank_prefix_sum,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int hidden_int4,
int num_scales,
int num_topk,
@@ -1179,6 +1189,22 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
}

// Clean the unused tail of `recv_topk_idx` by setting it to -1
if (num_worst_tokens > 0) {
if (is_forwarder)
return;
// Get the actual number of received tokens on the current rank
int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];
// Some `ForwarderCoordinator` threads exit early, so only non-forwarder threads take part in the clean-up
// `channel_id * num_threads` is the starting offset for the current non-forwarder SM
const auto clean_start = num_recv_tokens * num_topk + channel_id * num_threads;
const auto clean_end = num_worst_tokens * num_topk;
const auto clean_stride = num_channels * num_threads;
#pragma unroll
for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
recv_topk_idx[i] = -1;
}
}

void dispatch(void* recv_x,
@@ -1200,6 +1226,7 @@ void dispatch(void* recv_x,
const int* recv_gbl_rank_prefix_sum,
const bool* is_token_in_rank,
int num_tokens,
int num_worst_tokens,
int hidden_int4,
int num_scales,
int num_topk,
@@ -1254,6 +1281,7 @@ void dispatch(void* recv_x,
recv_gbl_rank_prefix_sum, \
is_token_in_rank, \
num_tokens, \
num_worst_tokens, \
hidden_int4, \
num_scales, \
num_topk, \
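The `recv_topk_idx` clean-up added to the dispatch kernel above pads the region between the real receive count and `num_worst_tokens` with `-1`, using a grid-stride pattern over the non-forwarder SMs. The sketch below re-creates that index arithmetic on the host to show that every padded slot is visited exactly once; the concrete values for channels, threads, and token counts are made up for illustration.

```python
# Host-side re-creation of the clean-up loop's indexing (illustrative values only).
num_channels, num_threads = 4, 768            # non-forwarder SMs and threads per SM (assumed)
num_topk, num_recv_tokens, num_worst_tokens = 8, 1000, 4096

clean_end = num_worst_tokens * num_topk
clean_stride = num_channels * num_threads
cleaned = []
for channel_id in range(num_channels):
    clean_start = num_recv_tokens * num_topk + channel_id * num_threads
    for thread_id in range(num_threads):
        i = clean_start + thread_id
        while i < clean_end:
            cleaned.append(i)                 # the kernel writes recv_topk_idx[i] = -1 here
            i += clean_stride

# Every slot past the real tokens is cleaned exactly once; no real token slot is touched.
assert sorted(cleaned) == list(range(num_recv_tokens * num_topk, clean_end))
```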
11 changes: 5 additions & 6 deletions deep_ep/buffer.py
@@ -374,10 +374,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],

# Internode
if self.runtime.get_num_rdma_ranks() > 1:
assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank,
num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, config, previous_event,
async_finish, allocate_on_comm_stream)
num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, num_worst_tokens, config,
previous_event, async_finish, allocate_on_comm_stream)

# Launch the kernel with cached or non-cached mode
x, x_scales = x if isinstance(x, tuple) else (x, None)
@@ -456,7 +455,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
config: Optional[Config] = None,
num_worst_tokens: int = 0, config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
allocate_on_comm_stream: bool = False) -> \
Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
@@ -480,7 +479,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
else:
assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
@@ -492,7 +491,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
x, x_scales, topk_idx, topk_weights,
num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
0, 0, None, None, None, None,
expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
send_nvl_head)
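With the Python wrapper now forwarding `num_worst_tokens` to `internode_dispatch`, multi-node callers can request fixed-size outputs the same way intranode callers already could. A minimal usage sketch follows (not a standalone script): it assumes an initialized `deep_ep.Buffer` spanning more than one RDMA rank and routing tensors obtained from `get_dispatch_layout`; the sizes and variable names are illustrative.

```python
import torch

# Sketch only: assumes `buffer = deep_ep.Buffer(group, ...)` spans >1 RDMA ranks, and that
# topk_idx/topk_weights plus the num_tokens_per_* / is_token_in_rank tensors were produced
# by buffer.get_dispatch_layout(...).
num_tokens, hidden, num_ranks = 4096, 7168, 16            # illustrative sizes
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device='cuda')

num_worst_tokens = num_tokens * num_ranks                  # static upper bound on received tokens

recv_x, recv_topk_idx, recv_topk_weights, recv_count_list, handle, event = buffer.dispatch(
    x, topk_idx=topk_idx, topk_weights=topk_weights,
    num_tokens_per_rank=num_tokens_per_rank,
    num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
    is_token_in_rank=is_token_in_rank,
    num_tokens_per_expert=num_tokens_per_expert,
    num_worst_tokens=num_worst_tokens)

# No CPU-side token count is awaited: outputs are padded to `num_worst_tokens` rows, padded
# top-k entries are -1, and the per-expert count list comes back empty (see the test below).
assert recv_x.size(0) == recv_topk_idx.size(0) == num_worst_tokens
assert len(recv_count_list) == 0
```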
18 changes: 18 additions & 0 deletions tests/test_internode.py
@@ -154,6 +154,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
if not is_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
recv_topk_weights_clone = None
if with_topk:
# Check `topk_idx`
assert (recv_topk_idx.eq(-1) |
@@ -163,11 +164,28 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
assert recv_topk_idx.eq(i).sum().item() == count

# Check `topk_weights`
recv_topk_weights_clone = recv_topk_weights.clone()
if not is_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(
dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

# Test `num_worst_tokens != 0`
if with_topk:
num_worst_tokens = num_tokens * num_ranks
dispatch_args.update({'num_worst_tokens': num_worst_tokens})
recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0)
assert num_worst_tokens == recv_worst_topk_weights.size(0)
assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()

# Test cached dispatch (must be without top-k stuff)
if not with_topk:
dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
Expand Down