@@ -314,7 +314,11 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                              async_finish, allocate_on_comm_stream)
-        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank) if t is not None)
+            return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event, tensors_to_record)
+        else:
+            return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)

     # noinspection PyTypeChecker
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
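The pattern in the hunk above repeats through the rest of this diff: whenever `async_finish` is set, the `EventOverlap` returned to the caller now also pins every tensor the asynchronous kernel reads or writes, with the `if t is not None` filter dropping optional arguments such as `x_scales` or `topk_weights`. As a sketch of why this works, here is a minimal stand-in for such a wrapper, assuming a plain `torch.cuda.Event`; the real `EventOverlap` wraps the runtime's own event handle and may differ in detail:

```python
import torch
from typing import Optional, Tuple

class EventOverlapSketch:
    """Minimal stand-in for an EventOverlap that pins extra tensors.

    While any caller holds this object, Python reference counting keeps
    `extra_tensors` alive, so the caching allocator cannot hand their
    storage to a new allocation while the asynchronous communication
    kernel is still reading or writing them.
    """

    def __init__(self, event: Optional[torch.cuda.Event] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
        self.event = event
        # Never read again; held purely to extend the tensors' lifetime.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        # Block the current stream (not the host) until the event fires.
        assert self.event is not None
        self.event.wait(torch.cuda.current_stream())
```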
@@ -386,7 +390,11 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                 expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_x_scales, recv_src_idx) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
@@ -395,10 +403,13 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                                                  expert_alignment, num_worst_tokens, config,
                                                  getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
+                                                      is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix,
+                                                      recv_x, recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, send_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)

     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
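The hazard these changes guard against is the usual CUDA stream-lifetime rule: with `allocate_on_comm_stream=True`, buffers live on a communication stream, and if the last Python reference dies before that stream's work finishes, PyTorch's caching allocator may recycle the memory mid-kernel. The same rule can be demonstrated with plain PyTorch streams (illustrative only, not DeepEP code):

```python
import torch

def matmul_on_side_stream(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute a @ b on a side stream, observing the allocator's rules."""
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())  # order: current -> side
    with torch.cuda.stream(side):
        out = a @ b                                # `out` is allocated on `side`
    # Tell the allocator that `out` is also consumed on the current stream,
    # so its storage is not reused until both streams move past this point...
    out.record_stream(torch.cuda.current_stream())
    # ...and order side -> current before the caller reads `out`.
    torch.cuda.current_stream().wait_stream(side)
    return out
```

Holding explicit references inside `EventOverlap`, as this diff does, achieves the same safety as `record_stream` without per-block allocator bookkeeping, at the cost of keeping the tensors alive until the wrapper is dropped.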
@@ -446,7 +457,11 @@ def combine(self, x: torch.Tensor, handle: Tuple,
                                                                    channel_prefix_matrix, send_head, config,
                                                                    getattr(previous_event, 'event',
                                                                            None), async_finish, allocate_on_comm_stream)
-        return recv_x, recv_topk_weights, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, recv_x, recv_topk_weights) if t is not None)
+            return recv_x, recv_topk_weights, EventOverlap(event, tensors_to_record)
+        else:
+            return recv_x, recv_topk_weights, EventOverlap(event)

     # noinspection PyTypeChecker
     def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
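From the caller's side, the contract is: keep the returned `EventOverlap` alive and call its wait before touching any output on the current stream. A hedged usage sketch, assuming an initialized `Buffer` and the keyword names visible in this diff (other parameters omitted):

```python
import torch

def async_dispatch(buffer, x: torch.Tensor, topk_idx: torch.Tensor, num_experts: int):
    """Hypothetical async usage of the dispatch path; `buffer` is a Buffer."""
    # Compute the dispatch layout asynchronously; returns immediately.
    num_per_rank, num_per_rdma_rank, num_per_expert, is_in_rank, layout_ev = \
        buffer.get_dispatch_layout(topk_idx, num_experts, async_finish=True)

    # Chain dispatch behind the layout kernel via previous_event.
    recv_x, recv_idx, recv_weights, recv_counts, handle, dispatch_ev = buffer.dispatch(
        x, num_tokens_per_rank=num_per_rank, num_tokens_per_rdma_rank=num_per_rdma_rank,
        num_tokens_per_expert=num_per_expert, is_token_in_rank=is_in_rank,
        topk_idx=topk_idx, previous_event=layout_ev,
        async_finish=True, allocate_on_comm_stream=True)

    # Unrelated work may run on the current stream here.

    # Make the current stream wait before any output is consumed; until this
    # point, dispatch_ev is also what keeps the comm-stream buffers alive.
    dispatch_ev.current_stream_wait()
    return recv_x, recv_idx, recv_weights, recv_counts, handle
```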
@@ -479,7 +494,14 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                 expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, is_token_in_rank, recv_x, recv_x_scales,
+                                                      rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                                      recv_rdma_channel_prefix_matrix, recv_src_meta, send_rdma_head, send_nvl_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
@@ -494,10 +516,16 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            if async_finish:
+                tensors_to_record = tuple(t for t in (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                                                      is_token_in_rank, recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+                                                      rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+                                                      recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+                                                      recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                                      recv_src_meta, send_rdma_head, send_nvl_head) if t is not None)
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+            else:
+                return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)

     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
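Putting dispatch and combine together, the full overlapped round trip looks roughly like the following sketch (`run_experts` is a placeholder for the expert computation; signatures hedged as above). Note that dropping an `EventOverlap` before waiting on it would release the pinned tensors and reintroduce exactly the use-after-free race this diff closes:

```python
import torch

def moe_forward_overlapped(buffer, x, topk_idx, topk_weights, num_experts, run_experts):
    """Hedged end-to-end sketch of an overlapped MoE dispatch/combine."""
    n_rank, n_rdma, n_expert, in_rank, ev = buffer.get_dispatch_layout(
        topk_idx, num_experts, async_finish=True)

    recv_x, recv_idx, recv_w, recv_counts, handle, ev = buffer.dispatch(
        x, num_tokens_per_rank=n_rank, num_tokens_per_rdma_rank=n_rdma,
        num_tokens_per_expert=n_expert, is_token_in_rank=in_rank,
        topk_idx=topk_idx, topk_weights=topk_weights,
        previous_event=ev, async_finish=True, allocate_on_comm_stream=True)
    ev.current_stream_wait()  # the expert computation reads recv_x next

    expert_out = run_experts(recv_x, recv_idx, recv_w, recv_counts)

    combined_x, combined_w, ev = buffer.combine(
        expert_out, handle, topk_weights=recv_w, async_finish=True)
    ev.current_stream_wait()  # keep `ev` (and its pinned tensors) alive until here
    return combined_x, combined_w
```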
@@ -527,7 +555,13 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                                                               send_rdma_head, send_nvl_head, config,
                                                               getattr(previous_event, 'event',
                                                                       None), async_finish, allocate_on_comm_stream)
-        return combined_x, combined_topk_weights, EventOverlap(event)
+        if async_finish:
+            tensors_to_record = tuple(t for t in (x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank,
+                                                  rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                                                  send_rdma_head, send_nvl_head, combined_x, combined_topk_weights) if t is not None)
+            return combined_x, combined_topk_weights, EventOverlap(event, tensors_to_record)
+        else:
+            return combined_x, combined_topk_weights, EventOverlap(event)

     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """