@@ -374,10 +374,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
 
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
             return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank,
-                                           num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, config, previous_event,
-                                           async_finish, allocate_on_comm_stream)
+                                           num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, num_worst_tokens, config,
+                                           previous_event, async_finish, allocate_on_comm_stream)
 
         # Launch the kernel with cached or non-cached mode
         x, x_scales = x if isinstance(x, tuple) else (x, None)
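
With this hunk, `dispatch` no longer asserts `num_worst_tokens == 0` on multi-node runs and instead forwards the value to `internode_dispatch`. Below is a minimal usage sketch, not taken from this PR: `buffer`, `group`, `x`, `topk_idx`, and `num_experts` are assumed placeholders, and the worst-case bound simply assumes every rank could send all of its tokens to this rank. Sizing the receive tensors to a fixed worst case is what lets the output shapes stay static.

import torch.distributed as dist

# Standard DeepEP layout computation (sketch; `buffer` is an existing deep_ep.Buffer).
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, _ = \
    buffer.get_dispatch_layout(topk_idx, num_experts)

# Assumed worst-case bound: all tokens from all ranks land on this rank.
num_worst_tokens = x.size(0) * dist.get_world_size(group)

recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
    buffer.dispatch(x, num_tokens_per_rank=num_tokens_per_rank,
                    num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                    is_token_in_rank=is_token_in_rank,
                    num_tokens_per_expert=num_tokens_per_expert,
                    topk_idx=topk_idx, topk_weights=topk_weights,
                    num_worst_tokens=num_worst_tokens)
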
@@ -456,7 +455,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                            num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                            is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
                            topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
-                           config: Optional[Config] = None,
+                           num_worst_tokens: int = 0, config: Optional[Config] = None,
                            previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                            allocate_on_comm_stream: bool = False) -> \
             Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], Optional[torch.Tensor],
@@ -480,7 +479,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch(
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
@@ -492,7 +491,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 x, x_scales, topk_idx, topk_weights,
                 num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                 0, 0, None, None, None, None,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
@@ -526,7 +525,8 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
         combined_x, combined_topk_weights, event = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta,
                                                                                   is_combined_token_in_rank, rdma_channel_prefix_matrix,
                                                                                   rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
-                                                                                  send_rdma_head, send_nvl_head, config,
+                                                                                  gbl_rank_prefix_sum, send_rdma_head,
+                                                                                  send_nvl_head, config,
                                                                                   getattr(previous_event, 'event',
                                                                                           None), async_finish, allocate_on_comm_stream)
         return combined_x, combined_topk_weights, EventOverlap(event)
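
The final hunk additionally passes `gbl_rank_prefix_sum` to the runtime combine kernel; the Python-level API is unchanged, so the round trip still looks like the following hedged continuation of the sketch above, with the same assumed names:

# The handle returned by dispatch carries the prefix matrices/sums that
# internode_combine consumes, including gbl_rank_prefix_sum.
combined_x, combined_topk_weights, event = buffer.combine(recv_x, handle,
                                                          topk_weights=recv_topk_weights)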