Commit 51540fa

feat: remove record_stream of normal mode
1 parent 4623c67 commit 51540fa

3 files changed (+57 / -111 lines)

csrc/deep_ep.cpp

Lines changed: 5 additions & 97 deletions
@@ -320,16 +320,7 @@ Buffer::get_dispatch_layout(
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {num_tokens_per_rdma_rank}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
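
For context, the problem both approaches address: a tensor allocated on one stream but consumed on another can have its memory returned to PyTorch's caching allocator, and handed out again, before the side-stream work finishes. `Tensor.record_stream()` asks the allocator to defer reuse until the recorded stream catches up; the alternative this commit switches to keeps Python references alive until the event is synchronized. Below is a minimal, illustrative sketch of the two options with generic names (`dispatch_async`, `recv`, `extra_tensors` are made up for the example, not taken from this repository); it requires a CUDA device to run.

import torch

comm_stream = torch.cuda.Stream()

def dispatch_async(x: torch.Tensor):
    # Make the side stream wait for prior work on the current (compute) stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        recv = torch.empty_like(x)
        recv.copy_(x, non_blocking=True)   # consumes x on comm_stream
        event = torch.cuda.Event()
        event.record(comm_stream)

    # Option A (what the C++ side used to do): tell the caching allocator that
    # x/recv are in use on comm_stream so their memory is not reused too early.
    #   x.record_stream(comm_stream); recv.record_stream(comm_stream)
    # Option B (this commit's scheme): keep references alive until the event is waited on.
    extra_tensors = (x, recv)
    return recv, event, extra_tensors

x = torch.randn(8, 16, device='cuda')
recv, event, extra_tensors = dispatch_async(x)
# ... independent compute can overlap here ...
torch.cuda.current_stream().wait_event(event)
extra_tensors = None   # now the allocator may safely reuse the memory

Holding references gives a deterministic release point (when the caller synchronizes) instead of allocator-side per-block tracking.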
@@ -606,32 +597,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        rank_prefix_matrix,
-                        channel_prefix_matrix,
-                        recv_x,
-                        recv_src_idx,
-                        recv_channel_prefix_matrix,
-                        send_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_expert,
-                         cached_channel_prefix_matrix,
-                         cached_rank_prefix_matrix,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -774,16 +740,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1121,39 +1078,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        recv_x,
-                        rdma_channel_prefix_matrix,
-                        recv_rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        recv_gbl_rank_prefix_sum}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_rdma_rank,
-                         num_tokens_per_expert,
-                         cached_rdma_channel_prefix_matrix,
-                         cached_recv_rdma_rank_prefix_sum,
-                         cached_gbl_channel_prefix_matrix,
-                         cached_recv_gbl_rank_prefix_sum,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales,
-                         recv_rdma_channel_prefix_matrix,
-                         recv_gbl_channel_prefix_matrix,
-                         send_rdma_head,
-                         send_nvl_head,
-                         recv_src_meta}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1338,24 +1263,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        src_meta,
-                        is_combined_token_in_rank,
-                        rdma_channel_prefix_matrix,
-                        rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        combined_x,
-                        combined_rdma_head,
-                        combined_nvl_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }

deep_ep/buffer.py

Lines changed: 47 additions & 13 deletions
@@ -314,7 +314,11 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                              async_finish, allocate_on_comm_stream)
-        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank) if t is not None)
+            return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event, tensors_to_record)
+        else:
+            return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
 
     # noinspection PyTypeChecker
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -386,7 +390,11 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                 expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_x_scales, recv_src_idx) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
@@ -395,10 +403,13 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     expert_alignment, num_worst_tokens, config,
                     getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
+                                                      is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix,
+                                                      recv_x, recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, send_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
 
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
@@ -446,7 +457,11 @@ def combine(self, x: torch.Tensor, handle: Tuple,
                 channel_prefix_matrix, send_head, config,
                 getattr(previous_event, 'event',
                         None), async_finish, allocate_on_comm_stream)
-        return recv_x, recv_topk_weights, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, recv_x, recv_topk_weights) if t is not None)
+            return recv_x, recv_topk_weights, EventOverlap(event, tensors_to_record)
+        else:
+            return recv_x, recv_topk_weights, EventOverlap(event)
 
     # noinspection PyTypeChecker
     def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -479,7 +494,14 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Te
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                 expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, is_token_in_rank, recv_x, recv_x_scales,
+                                                      rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                                      recv_rdma_channel_prefix_matrix, recv_src_meta, send_rdma_head, send_nvl_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
@@ -494,10 +516,16 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Te
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                                                      is_token_in_rank, recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+                                                      rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+                                                      recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+                                                      recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                                      recv_src_meta, send_rdma_head, send_nvl_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
 
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
@@ -527,7 +555,13 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                 send_rdma_head, send_nvl_head, config,
                 getattr(previous_event, 'event',
                         None), async_finish, allocate_on_comm_stream)
-        return combined_x, combined_topk_weights, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank,
+                                                  rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                                                  send_rdma_head, send_nvl_head, combined_x, combined_topk_weights) if t is not None)
+            return combined_x, combined_topk_weights, EventOverlap(event, tensors_to_record)
+        else:
+            return combined_x, combined_topk_weights, EventOverlap(event)
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """

deep_ep/utils.py

Lines changed: 5 additions & 1 deletion
@@ -33,9 +33,12 @@ def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[
     def current_stream_wait(self) -> None:
         """
         The current stream `torch.cuda.current_stream()` waits for the event to be finished.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         assert self.event is not None
         self.event.current_stream_wait()
+        # Release tensor references after synchronization is complete
+        self.extra_tensors = None
 
     def __enter__(self) -> Any:
         """
@@ -56,9 +59,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         Utility for overlapping and Python `with` syntax.
 
         Please follow the example in the `__enter__` function.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         if self.event is not None:
-            self.event.current_stream_wait()
+            self.current_stream_wait()
 
 
 def check_nvlink_connections(group: dist.ProcessGroup):