Commit d5e6717 (parent: 9249c25)

Enable CUDA Graph for internode dispatch

6 files changed: +93, -37 lines

csrc/deep_ep.cpp

Lines changed: 30 additions & 22 deletions
@@ -940,6 +940,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
     const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
     int expert_alignment,
+    int num_worst_tokens,
     const Config& config,
     std::optional<EventHandle>& previous_event,
     bool async,

@@ -1112,6 +1113,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
         num_experts,
         is_token_in_rank.data_ptr<bool>(),
         num_tokens,
+        num_worst_tokens,
         num_channels,
         hidden_int4,
         num_scales,

@@ -1133,30 +1135,35 @@ Buffer::internode_dispatch(const torch::Tensor& x,
             low_latency_mode);

         // Synchronize total received tokens and tokens per expert
-        auto start_time = std::chrono::high_resolution_clock::now();
-        while (true) {
-            // Read total count
-            num_recv_tokens = static_cast<int>(*moe_recv_counter);
-            num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
-
-            // Read per-expert count
-            bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
-            for (int i = 0; i < num_local_experts and ready; ++i)
-                ready &= moe_recv_expert_counter[i] >= 0;
-
-            if (ready)
-                break;
-
-            // Timeout check
-            if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
-                NUM_CPU_TIMEOUT_SECS) {
-                printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
-                for (int i = 0; i < num_local_experts; ++i)
-                    printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
-                throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+        if (num_worst_tokens > 0) {
+            num_recv_tokens = num_worst_tokens;
+            num_rdma_recv_tokens = num_worst_tokens;
+        } else {
+            auto start_time = std::chrono::high_resolution_clock::now();
+            while (true) {
+                // Read total count
+                num_recv_tokens = static_cast<int>(*moe_recv_counter);
+                num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
+
+                // Read per-expert count
+                bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
+                for (int i = 0; i < num_local_experts and ready; ++i)
+                    ready &= moe_recv_expert_counter[i] >= 0;
+
+                if (ready)
+                    break;
+
+                // Timeout check
+                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() >
+                    NUM_CPU_TIMEOUT_SECS) {
+                    printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens);
+                    for (int i = 0; i < num_local_experts; ++i)
+                        printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]);
+                    throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+                }
             }
+            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
         }
-        num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
     }

     // Allocate new tensors

@@ -1213,6 +1220,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
         recv_gbl_rank_prefix_sum.data_ptr<int>(),
         is_token_in_rank.data_ptr<bool>(),
         num_tokens,
+        num_worst_tokens,
         hidden_int4,
         num_scales,
         num_topk,
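
Note on the host-side change above: the deleted busy-wait polls `moe_recv_counter` and the per-expert counters through mapped host memory, so the CPU must observe GPU writes before the output shapes are known, and host synchronization of that kind is not permitted while a CUDA graph is being captured. A minimal PyTorch repro of the underlying restriction (not DeepEP code; the commented-out `.item()` read stands in for the counter poll):

import torch

x = torch.zeros(8, device='cuda')
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = x + 1
    # Uncommenting the next line breaks capture: reading a device value on
    # the host synchronizes the stream, which is illegal while capturing.
    # n = int(y.sum().item())
graph.replay()

With `num_worst_tokens > 0`, the new branch skips the poll entirely and pins `num_recv_tokens` and `num_rdma_recv_tokens` to the caller-supplied worst case; `num_recv_tokens_per_expert_list` is left empty in this mode, which the new test below checks.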

csrc/deep_ep.hpp

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ struct Buffer {
     const std::optional<torch::Tensor>& cached_gbl_channel_prefix_matrix,
     const std::optional<torch::Tensor>& cached_recv_gbl_rank_prefix_sum,
     int expert_alignment,
+    int num_worst_tokens,
     const Config& config,
     std::optional<EventHandle>& previous_event,
     bool async,

csrc/kernels/api.cuh

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int hidden_int4,
     int num_scales,

@@ -193,6 +194,7 @@ void dispatch(void* recv_x,
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,

csrc/kernels/internode.cu

Lines changed: 37 additions & 9 deletions
@@ -100,6 +100,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int expert_alignment,
     const int rdma_clean_offset,

@@ -236,9 +237,11 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
             sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
             recv_rdma_rank_prefix_sum[i] = sum;
         }
-        while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
-            ;
-        *moe_recv_rdma_counter_mapped = sum;
+        if (num_worst_tokens == 0) {
+            while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
+                ;
+            *moe_recv_rdma_counter_mapped = sum;
+        }
     }

     // Send numbers of tokens per rank/expert to NVL ranks

@@ -263,19 +266,23 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank,
             sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
             recv_gbl_rank_prefix_sum[i] = sum;
         }
-        while (ld_volatile_global(moe_recv_counter_mapped) != -1)
-            ;
-        *moe_recv_counter_mapped = sum;
+        if (num_worst_tokens == 0) {
+            while (ld_volatile_global(moe_recv_counter_mapped) != -1)
+                ;
+            *moe_recv_counter_mapped = sum;
+        }
     }
     if (thread_id < num_nvl_experts) {
         int sum = 0;
         #pragma unroll
         for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
             sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
         sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
-        while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
-            ;
-        moe_recv_expert_counter_mapped[thread_id] = sum;
+        if (num_worst_tokens == 0) {
+            while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1)
+                ;
+            moe_recv_expert_counter_mapped[thread_id] = sum;
+        }
     }

     // Finally barrier

@@ -346,6 +353,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
     int num_experts,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int num_channels,
     int hidden_int4,
     int num_scales,

@@ -380,6 +388,7 @@ void notify_dispatch(const int* num_tokens_per_rank,
         num_experts, \
         is_token_in_rank, \
         num_tokens, \
+        num_worst_tokens, \
         num_channels, \
         expert_alignment, \
         rdma_clean_meta.first, \

@@ -455,6 +464,7 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,
@@ -1179,6 +1189,22 @@ __global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NV
                 st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
         }
     }
+
+    // Clean unused `recv_topk_idx` as -1
+    if (num_worst_tokens > 0) {
+        if (is_forwarder)
+            return;
+        // Get the actual number of received tokens on the current rank
+        int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];
+        // Some `ForwarderCoordinator` threads exit early, so only non-forwarder threads do the clean-up;
+        // `channel_id * num_threads` is the offset of the current non-forwarder SM
+        const auto clean_start = num_recv_tokens * num_topk + channel_id * num_threads;
+        const auto clean_end = num_worst_tokens * num_topk;
+        const auto clean_stride = num_channels * num_threads;
+        #pragma unroll
+        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+            recv_topk_idx[i] = -1;
+    }
 }

 void dispatch(void* recv_x,
@@ -1200,6 +1226,7 @@ void dispatch(void* recv_x,
     const int* recv_gbl_rank_prefix_sum,
     const bool* is_token_in_rank,
     int num_tokens,
+    int num_worst_tokens,
     int hidden_int4,
     int num_scales,
     int num_topk,

@@ -1254,6 +1281,7 @@ void dispatch(void* recv_x,
         recv_gbl_rank_prefix_sum, \
         is_token_in_rank, \
         num_tokens, \
+        num_worst_tokens, \
         hidden_int4, \
         num_scales, \
         num_topk, \
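
The clean-up added to the dispatch kernel works as follows: each non-forwarder SM handles one channel, starts at its own offset past the `num_recv_tokens * num_topk` boundary, and strides by `num_channels * num_threads`, so the channels jointly write every padded slot of `recv_topk_idx` exactly once. A quick host-side Python model of that indexing, with made-up sizes (not the real launch shape), verifies the coverage:

# Hypothetical sizes; the real values come from the kernel launch.
num_recv_tokens, num_worst_tokens, num_topk = 37, 64, 8
num_channels, num_threads = 4, 96

covered = []
for channel_id in range(num_channels):      # one non-forwarder SM per channel
    clean_start = num_recv_tokens * num_topk + channel_id * num_threads
    clean_end = num_worst_tokens * num_topk
    clean_stride = num_channels * num_threads
    for thread_id in range(num_threads):    # each thread's grid-stride loop
        covered.extend(range(clean_start + thread_id, clean_end, clean_stride))

# The padded tail [num_recv_tokens * num_topk, num_worst_tokens * num_topk)
# is written exactly once.
assert sorted(covered) == list(range(num_recv_tokens * num_topk, num_worst_tokens * num_topk))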

deep_ep/buffer.py

Lines changed: 5 additions & 6 deletions
@@ -374,10 +374,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],

         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
             return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank,
-                                           num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, config, previous_event,
-                                           async_finish, allocate_on_comm_stream)
+                                           num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, num_worst_tokens, config,
+                                           previous_event, async_finish, allocate_on_comm_stream)

         # Launch the kernel with cached or non-cached mode
         x, x_scales = x if isinstance(x, tuple) else (x, None)

@@ -456,7 +455,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
                            num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                            is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
                            topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
-                           config: Optional[Config] = None,
+                           num_worst_tokens: int = 0, config: Optional[Config] = None,
                            previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                            allocate_on_comm_stream: bool = False) -> \
             Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],

@@ -480,7 +479,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None

@@ -492,7 +491,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te
                 x, x_scales, topk_idx, topk_weights,
                 num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                 0, 0, None, None, None, None,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
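
With `num_worst_tokens` plumbed through `dispatch` and `internode_dispatch`, output shapes no longer depend on device-side counters, so the internode dispatch can be recorded into a CUDA graph. A sketch of the usage this enables, assuming an initialized `deep_ep.Buffer` named `buffer` and routing tensors (`topk_idx`, `num_tokens_per_rank`, etc.) prepared as elsewhere in this repo; the keyword names follow the signatures above:

import torch

num_worst_tokens = num_tokens * num_ranks  # static worst-case bound, as in the test below

dispatch_args = dict(x=x, topk_idx=topk_idx, topk_weights=topk_weights,
                     num_tokens_per_rank=num_tokens_per_rank,
                     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                     is_token_in_rank=is_token_in_rank,
                     num_tokens_per_expert=num_tokens_per_expert,
                     num_worst_tokens=num_worst_tokens)

buffer.dispatch(**dispatch_args)  # warm up once outside capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = buffer.dispatch(**dispatch_args)
graph.replay()  # replays the dispatch with the same static buffers

Outputs are padded to `num_worst_tokens` rows, with the unused tail of `recv_topk_idx` set to -1 by the kernel clean-up above; the per-expert receive list comes back empty in this mode, so callers that need per-expert counts must derive them another way (e.g., from `recv_topk_idx`).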

tests/test_internode.py

Lines changed: 18 additions & 0 deletions
@@ -154,6 +154,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
     assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
     if not is_rand:
         check_data(recv_x, recv_gbl_rank_prefix_sum)
+    recv_topk_weights_clone = None
     if with_topk:
         # Check `topk_idx`
         assert (recv_topk_idx.eq(-1) |

@@ -163,11 +164,28 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
         assert recv_topk_idx.eq(i).sum().item() == count

         # Check `topk_weights`
+        recv_topk_weights_clone = recv_topk_weights.clone()
         if not is_rand:
             recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(
                 dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
             check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

+    # Test `num_worst_tokens != 0`
+    if with_topk:
+        num_worst_tokens = num_tokens * num_ranks
+        dispatch_args.update({'num_worst_tokens': num_worst_tokens})
+        recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args)
+        event.current_stream_wait() if async_mode else ()
+        recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
+        assert len(empty_list) == 0
+        assert num_worst_tokens == recv_worst_x.size(0)
+        assert num_worst_tokens == recv_worst_topk_idx.size(0)
+        assert num_worst_tokens == recv_worst_topk_weights.size(0)
+        assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
+        assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
+        assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
+        assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
+
     # Test cached dispatch (must without top-k staffs)
     if not with_topk:
         dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
