@@ -85,7 +85,7 @@ def _moe_problem_size(
         M = a1.size(0)
     else:
         assert a1.dim() == 3
-        # assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
         M = a1.size(1)  # This is max_num_tokens

     assert topk_ids.dim() == 2
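The hunk above re-enables a shape check in the batched (3-D) activation layout. A minimal standalone sketch of the same logic, with simplified stand-ins for the real `_moe_problem_size` arguments:

```python
import torch

def moe_problem_size_sketch(a1: torch.Tensor, E: int) -> int:
    """Return M, the per-expert token count, from the activation layout."""
    if a1.dim() == 2:
        # Standard layout: (M, K) -- M tokens, hidden size K.
        M = a1.size(0)
    else:
        # Batched layout: (E, max_num_tokens, K) -- one slab per expert,
        # so dim 0 must match the expert count E.
        assert a1.dim() == 3
        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
        M = a1.size(1)  # max_num_tokens
    return M

# Example: 4 experts, up to 16 tokens each, hidden size 128.
assert moe_problem_size_sketch(torch.empty(4, 16, 128), E=4) == 16
```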
@@ -536,11 +536,12 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-    ):
+    ) -> None:
         """
         This function computes the intermediate result of a Mixture of Experts
         (MoE) layer using two sets of weights, w1 and w2.
@@ -674,22 +675,22 @@ def _allocate_buffers(

         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.zeros(prod(workspace13_shape),
-                                  device=device,
-                                  dtype=workspace_dtype)
-        workspace2 = torch.zeros(prod(workspace2_shape),
-                                 device=device,
-                                 dtype=workspace_dtype)
+        workspace13 = self.workspace13_buffer.get(workspace13_shape,
+                                                  device=device,
+                                                  dtype=workspace_dtype)
+        workspace2 = self.workspace2_buffer.get(workspace2_shape,
+                                                device=device,
+                                                dtype=workspace_dtype)

         # Construct the entire output that can then be processed in chunks.
         if num_chunks == 1 and prod(workspace13_shape) >= prod(
                 fused_out_shape):
             # Reuse workspace13 for the output in the non-chunked case.
             fused_out = _resize_cache(workspace13, fused_out_shape)
         else:
-            fused_out = torch.empty(fused_out_shape,
-                                    device=device,
-                                    dtype=out_dtype)
+            fused_out = self.fused_out_buffer.get(fused_out_shape,
+                                                  device=device,
+                                                  dtype=out_dtype)

         return workspace13, workspace2, fused_out
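The hunk above swaps per-call `torch.zeros`/`torch.empty` allocations for persistent buffer objects. The buffer class itself is not in this hunk, so the sketch below is an assumption about its behavior: a cached flat tensor that grows only when a request exceeds capacity and hands back a shaped view, so steady-state calls allocate nothing. Note the old zero-fill is dropped, presumably because the workspaces are fully overwritten before use.

```python
from math import prod
from typing import Optional
import torch

class ReusableBuffer:
    """Hypothetical grow-only workspace cache (name and API assumed)."""

    def __init__(self):
        self._data: Optional[torch.Tensor] = None

    def get(self, shape, device, dtype) -> torch.Tensor:
        numel = prod(shape)
        if (self._data is None or self._data.numel() < numel
                or self._data.device != torch.device(device)
                or self._data.dtype != dtype):
            # Only (re)allocate when the cached tensor cannot serve the request.
            self._data = torch.empty(numel, device=device, dtype=dtype)
        # Return a view of the prefix, shaped as requested.
        return self._data[:numel].view(*shape)

buf = ReusableBuffer()
w1 = buf.get((32, 128), device="cpu", dtype=torch.float32)
w2 = buf.get((16, 128), device="cpu", dtype=torch.float32)  # reuses storage
assert w1.untyped_storage().data_ptr() == w2.untyped_storage().data_ptr()
```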
@@ -785,7 +786,10 @@ def forward(
         - torch.Tensor: The output tensor after applying the MoE layer.
         """

-        output = hidden_states if inplace else torch.zeros_like(hidden_states)
+        if inplace and self.shared_experts is None:
+            output = hidden_states
+        else:
+            output = torch.zeros_like(hidden_states)

         local_num_experts = w1.size(0)
         if global_num_experts == -1:
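The extra condition matters because later in `forward` the shared experts now consume `hidden_states` directly (see the `shared_experts(hidden_states)` hunks below). If `output` aliased `hidden_states`, the fused-expert result could be written into that storage before the shared experts read it. A toy illustration of the aliasing hazard, not vLLM code:

```python
import torch

hidden_states = torch.ones(4, 8)
output = hidden_states           # in-place aliasing
output.zero_()                   # fused experts write their result...
assert hidden_states.sum() == 0  # ...and the shared-expert input is gone
```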
@@ -799,8 +803,6 @@ def forward(
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  hidden_states,
-                 a1_scale,
-                 a2_scale,
                  topk_weights,
                  topk_ids,
                  global_num_experts,
@@ -810,10 +812,9 @@ def forward(
             )
         else:
             # Overlap shared expert compute with all2all dispatch.
-            receiver = self.prepare_finalize.prepare_async(
+            dbo_maybe_run_recv_hook()
+            hook, receiver = self.prepare_finalize.prepare_async(
                 hidden_states,
-                a1_scale,
-                a2_scale,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
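`prepare_async` now returns a `(hook, receiver)` pair instead of just a receiver. The exact contract isn't visible in this hunk; the stub below is a guess at its shape: the hook completes the all2all and the receiver materializes the dispatched data, which leaves a window in between for independent work such as the shared experts.

```python
class FakePrepareFinalize:
    """Stand-in for the real prepare/finalize object (internals assumed)."""

    def prepare_async(self, hidden_states):
        done = {"flag": False}

        def hook():      # would block until the all2all comm lands
            done["flag"] = True

        def receiver():  # would materialize the dispatched tensors
            assert done["flag"], "hook() must run before receiver()"
            return hidden_states

        return hook, receiver

pf = FakePrepareFinalize()
hook, receiver = pf.prepare_async([1, 2, 3])
overlapped = sum([1, 2, 3])  # shared-expert-style work overlaps here
hook()
a1q = receiver()
```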
@@ -838,6 +839,8 @@ def forward(
         topk_weights = (topk_weights if _expert_topk_weights is None else
                         _expert_topk_weights)

+        fused_out = None
+
         if a1q.numel() == 0:
             # This happens when none of the tokens from the all2all reach this
             # EP rank. Also, note that this is only relevant for CUDAGraph
@@ -853,7 +856,7 @@ def forward(
                 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                 num_chunks = cdiv(M, CHUNK_SIZE)
             else:
-                CHUNK_SIZE = M #a1q.size(0)
+                CHUNK_SIZE = M  #a1q.size(0)
                 num_chunks = 1

             def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
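The chunking arithmetic above, restated standalone: `cdiv` is ceiling division, and `input_chunk_range` maps a chunk index to a `[start, end)` token range, with the last chunk possibly shorter than `CHUNK_SIZE`. (The real `input_chunk_range` is a closure over `M` and `CHUNK_SIZE`; it is parameterized here for a self-contained example.)

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers.
    return -(a // -b)

def input_chunk_range(chunk_idx: int, chunk_size: int,
                      num_tokens: int) -> tuple[int, int]:
    s = chunk_idx * chunk_size
    e = min(s + chunk_size, num_tokens)
    return s, e

M, CHUNK_SIZE = 10, 4
num_chunks = cdiv(M, CHUNK_SIZE)  # 3 chunks: two full, one ragged
assert [input_chunk_range(i, CHUNK_SIZE, M)
        for i in range(num_chunks)] == [(0, 4), (4, 8), (8, 10)]
```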
@@ -892,12 +895,8 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                     activation=activation,
                     global_num_experts=global_num_experts,
                     expert_map=expert_map,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    w1_zp=w1_zp,
-                    w2_zp=w2_zp,
                     a1q_scale=_chunk_scales(a1q_scale, s, e),
-                    a2_scale=_chunk_scales(a2_scale, e, e),
+                    a2_scale=_chunk_scales(self.fused_experts.a2_scale, e, e),
                     workspace13=workspace13,
                     workspace2=workspace2,
                     expert_tokens_meta=c_expert_tokens_meta,
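This hunk drops the per-weight scale/zero-point kwargs and reads `a2_scale` from `self.fused_experts` instead of a loose argument. `_chunk_scales` itself isn't shown; its usage suggests behavior along these lines (an inference, not the actual helper): a per-tensor scale (single element) passes through untouched for every chunk, while per-token scales are sliced to the chunk's rows.

```python
from typing import Optional
import torch

def chunk_scales_sketch(scales: Optional[torch.Tensor], start: int,
                        end: int) -> Optional[torch.Tensor]:
    if scales is None:
        return None
    if scales.numel() == 1:
        return scales            # per-tensor scale applies to every chunk
    return scales[start:end]     # per-token scales follow the token slice

per_tensor = torch.tensor([0.5])
per_token = torch.arange(10, dtype=torch.float32)
assert chunk_scales_sketch(per_tensor, 4, 8).numel() == 1
assert chunk_scales_sketch(per_token, 4, 8).tolist() == [4.0, 5.0, 6.0, 7.0]
```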
@@ -918,7 +917,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)
         else:
             recv_hook = self.prepare_finalize.finalize_async(
                 output,
@@ -930,7 +929,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
             )

             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)

             assert recv_hook is not None
             dbo_register_recv_hook(recv_hook)