@@ -790,81 +790,6 @@ def forward(
790790 # Fully Connected
791791 hidden_states , residual = self .post_attention_layernorm (
792792 hidden_states , residual )
793- if self .role is not None and self .layer_idx >= self .first_k_dense_replace :
794- # --------- ffn need data
795- moe_comm_type = forward_ctx .moe_comm_type
796- num_tokens = hidden_states .shape [0 ]
797- with_prefill = forward_ctx .with_prefill
798- num_actual_tokens = None
799- ffn_need_forward_data = FFNNeedForwardData (moe_comm_type ,num_tokens ,with_prefill ,num_actual_tokens )
800-
801- # if self.is_sequence_parallel:
802- # hidden_states = sequence_parallel_chunk(hidden_states)
803-
804- # router_logits: (num_tokens, n_experts)
805- #
806- router_logits , _ = self .gate (hidden_states )
807- # Add the gating call here; hidden_states will be sent along with it. Reuse the helper from fused_moe directly instead of pulling the function out.
808- topk_weights , topk_ids , row_idx = select_experts (
809- hidden_states = hidden_states ,
810- router_logits = router_logits ,
811- top_k = 8 ,
812- use_grouped_topk = False ,
813- renormalize = True ,
814- )
815-
816- topk_weights = topk_weights .to (torch .float )
817- # print(f'topk_weights shape is {topk_weights.shape},dtype is {topk_weights.dtype}')
818- # print(f'topk_ids shape is {topk_ids.shape},dtype is {topk_ids.dtype}')
819- # print(f'row_idx shape is {row_idx.shape},dtype is {row_idx.dtype}')
820-
821- if self .connector_name == "m2nconnector" :
822- from vllm_ascend .distributed .M2NAFDConnector import M2NAFDConnectorMetadata
823- m2n_afdconnector_data = M2NAFDConnectorMetadata ()
824- m2n_afdconnector_data .moe_expert_num = 64
825- m2n_afdconnector_data .quant_mode = 0
826- m2n_afdconnector_data .aiv_num = 48
827- m2n_afdconnector_data .scale = None
828-
829- if self .connector_name == "camconnector" :
830- cam_afdconnector_data = CAMAFDConnectorMetadata (
831- moe_expert_num = 64 ,
832- shared_expert_num = 0 ,
833- scale = None ,
834- handle = None ,
835- quant_mode = 0 ,
836- aiv_num = 48 ,
837- batch_size = hidden_states .shape [0 ],
838- h = 2048 ,
839- k = 6
840- )
841- # TODO(yxj): send this once per generated token
842- metadata = AFDConnectorMetadata .create_attention_metadata (
843- layer_idx = self .layer_idx ,
844- stage_idx = 0 ,
845- seq_len = hidden_states .shape [0 ],
846- dtype = hidden_states .dtype ,
847- device = hidden_states .device ,
848- ffn_need_forward_data = ffn_need_forward_data ,
849- m2n_afdconnector_data = m2n_afdconnector_data if self .connector_name == "m2nconnector" else None ,
850- cam_afdconnector_data = cam_afdconnector_data if self .connector_name == "camconnector" else None ,
851- )
852-
853- if self .connector_name == "m2nconnector" :
854- handle = afd_connector .send_attn_output (hidden_states ,topk_weights ,topk_ids ,metadata )
855- # print(f'send_attn_output success ,layer id is {self.layer_idx}')
856- metadata .m2n_afdconnector_data .handle = handle
857- hidden_states = afd_connector .recv_ffn_output (hidden_states ,metadata )
858- # print(f'recv_ffn_output success ,layer id is {self.layer_idx}')
859- elif self .connector_name == "camconnector" :
860- afd_connector .send_attn_output (hidden_states , topk_weights , topk_ids , metadata )
861- hidden_states = afd_connector .recv_ffn_output (metadata )
862- else :
863- afd_connector .send_attn_output (hidden_states ,router_logits ,topk_weights , topk_ids , row_idx , metadata )
864- hidden_states , _ = afd_connector .recv_ffn_output ()
865- if self .role == "attention" :
866- return hidden_states , residual
867-
868793 hidden_states = self .mlp (hidden_states )
869794
870795 if isinstance (self .mlp ,
@@ -877,7 +802,7 @@ def forward(
877802 hidden_states *= 1. / self .routed_scaling_factor
878803
879804 return hidden_states , residual
880-
805+
881806 def compute_attn_output (
882807 self ,
883808 positions : torch .Tensor ,
@@ -908,8 +833,24 @@ def compute_attn_output(
908833 # Fully Connected
909834 hidden_states , residual = self .post_attention_layernorm (
910835 hidden_states , residual )
836+
+ router_logits = None
837+ topk_weights = None
838+ topk_ids = None
839+ row_idx = None
840+ # Compute gate on attention side.
841+ if self .layer_idx >= self .first_k_dense_replace and self .afd_config .compute_gate_on_attention :
842+ router_logits , _ = self .gate (hidden_states )
843+ topk_weights , topk_ids , row_idx = select_experts (
844+ hidden_states = hidden_states ,
845+ router_logits = router_logits ,
846+ top_k = 8 ,
847+ use_grouped_topk = False ,
848+ renormalize = True ,
849+ )
850+
851+ topk_weights = topk_weights .to (torch .float )
911852
912- return hidden_states , residual
853+ return hidden_states , residual , router_logits , topk_weights , topk_ids , row_idx
913854
914855 def compute_ffn_output (self ,
915856 hidden_states : torch .Tensor ,
@@ -954,6 +895,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
954895 config = vllm_config .model_config .hf_config
955896 quant_config = vllm_config .quant_config
956897 self .config = config
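+ # AFD (attention/FFN disaggregation) settings: number of leading dense layers and the configured connector, if any.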
898+ self .first_k_dense_replace = config .first_k_dense_replace
899+ self .afd_config = vllm_config .afd_config
900+ self .connector_name = self .afd_config .afd_connector if self .afd_config is not None else None
957901
958902 self .vocab_size = config .vocab_size
959903
@@ -993,15 +937,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
993937 def get_input_embeddings (self , input_ids : torch .Tensor ) -> torch .Tensor :
994938 return self .embed_tokens (input_ids )
995939
996- def forward_with_afd (
940+ def forward_m2n (
997941 self ,
998942 hidden_states : torch .Tensor ,
999943 residual : torch .Tensor ,
1000944 positions : torch .Tensor ,
1001945 afd_metadata : AFDMetadata
1002- ) -> tuple [torch .Tensor , torch .Tensor ]:
946+ ) -> tuple [torch .Tensor , torch .Tensor ]:
1003947 recv_handle = None
948+
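+ # Collect the per-forward fields (MoE comm type, token count, prefill flag) that the FFN side needs; built once and reused for every MoE layer in this pass.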
949+ forward_ctx = get_forward_context ()
950+ moe_comm_type = forward_ctx .moe_comm_type
951+ num_tokens = hidden_states .shape [0 ]
952+ with_prefill = forward_ctx .with_prefill
953+ num_actual_tokens = None
954+ ffn_need_forward_data = FFNNeedForwardData (moe_comm_type ,num_tokens ,with_prefill ,num_actual_tokens )
955+
1004956 for layer in islice (self .layers , self .start_layer , self .end_layer ):
957+ # Compute dense layers on attn side.
958+ if layer .layer_idx < self .first_k_dense_replace :
959+ hidden_states , residual = layer (positions , hidden_states , residual )
960+ continue
961+
1005962 logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } metadata:{ afd_metadata } hidden_states:{ hidden_states .shape } " )
1006963 afd_connector = afd_metadata .afd_connector
1007964 afd_metadata .afd_stage_idx = dbo_current_ubatch_id ()
@@ -1010,21 +967,55 @@ def forward_with_afd(
1010967 logger .info (f"jcz deepseekv2 layer_idx:{ layer .layer_idx } start_loc:{ afd_metadata .afd_tokens_start_loc } "
1011968 f"start_idx:{ start_idx } end_idx:{ end_idx } "
1012969 f"stage_idx:{ afd_metadata .afd_stage_idx } " )
970+
1013971 if recv_handle is not None :
1014972 for work in recv_handle :
1015973 work .wait ()
1016- current_hidden , residual = layer (positions , hidden_states , residual )
974+
975+ current_hidden , residual , router_logits , topk_weights , topk_ids , row_idx = \
976+ layer .compute_attn_output (positions , hidden_states , residual )
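+ # Build connector-specific metadata; the expert counts, quant mode and AIV numbers below are hardcoded for this path.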
977+ if self .connector_name == "m2nconnector" :
978+ from vllm_ascend .distributed .M2NAFDConnector import M2NAFDConnectorMetadata
979+ m2n_afdconnector_data = M2NAFDConnectorMetadata ()
980+ m2n_afdconnector_data .moe_expert_num = 64
981+ m2n_afdconnector_data .quant_mode = 0
982+ m2n_afdconnector_data .aiv_num = 48
983+ m2n_afdconnector_data .scale = None
984+ if self .connector_name == "camconnector" :
985+ cam_afdconnector_data = CAMAFDConnectorMetadata (
986+ moe_expert_num = 64 ,
987+ shared_expert_num = 0 ,
988+ scale = None ,
989+ handle = None ,
990+ quant_mode = 0 ,
991+ aiv_num = 48 ,
992+ batch_size = hidden_states .shape [0 ],
993+ h = 2048 ,
994+ k = 6
995+ )
996+
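+ # Wrap everything into AFDConnectorMetadata so the receiving FFN workers know the layer, stage index and tensor layout.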
1017997 metadata = AFDConnectorMetadata .create_attention_metadata (
1018998 layer_idx = layer .layer_idx ,
1019999 stage_idx = afd_metadata .afd_stage_idx ,
1020- seq_len = current_hidden .shape [0 ],
1021- dtype = current_hidden .dtype ,
1022- device = current_hidden .device ,
1000+ seq_len = hidden_states .shape [0 ],
1001+ dtype = hidden_states .dtype ,
1002+ device = hidden_states .device ,
1003+ ffn_need_forward_data = ffn_need_forward_data ,
1004+ m2n_afdconnector_data = m2n_afdconnector_data if self .connector_name == "m2nconnector" else None ,
1005+ cam_afdconnector_data = cam_afdconnector_data if self .connector_name == "camconnector" else None ,
10231006 )
1024- afd_connector .send_attn_output (current_hidden , metadata )
1025- hidden_states , recv_metadata = afd_connector .recv_ffn_output ()
1026- if recv_metadata .recv_handle_list is not None :
1027- recv_handle = recv_metadata .recv_handle_list
1007+
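+ # Send the attention output (plus routing results) to the FFN side and wait for the expert output, per connector type.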
1008+ if self .connector_name == "m2nconnector" :
1009+ handle = afd_connector .send_attn_output (current_hidden ,topk_weights ,topk_ids ,metadata )
1010+ metadata .m2n_afdconnector_data .handle = handle
1011+ hidden_states , recv_handle = afd_connector .recv_ffn_output (hidden_states ,metadata )
1012+ elif self .connector_name == "camconnector" :
1013+ afd_connector .send_attn_output (current_hidden , topk_weights , topk_ids , metadata )
1014+ hidden_states = afd_connector .recv_ffn_output (metadata )
1015+ else :
1016+ afd_connector .send_attn_output (current_hidden , router_logits , topk_weights , topk_ids , row_idx , metadata )
1017+ hidden_states , _ = afd_connector .recv_ffn_output ()
1018+
10281019 if dbo_enabled ():
10291020 dbo_yield ()
10301021 return hidden_states , residual
@@ -1048,18 +1039,15 @@ def forward(
10481039 residual = intermediate_tensors ["residual" ]
10491040
10501041 # TODO(jcz): this needs to be fixed later
1051- # forward_ctx = get_forward_context()
1052- # afd_metadata = (forward_ctx.afd_metadata
1053- # if forward_ctx is not None else None)
1054- # if afd_metadata != None:
1055- # hidden_states, residual = self.forward_with_afd(hidden_states, residual,
1056- # positions, afd_metadata)
1057- # else:
1058- # for layer in islice(self.layers, self.start_layer, self.end_layer):
1059- # hidden_states, residual = layer(positions, hidden_states, residual)
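+ # Route through the AFD path when AFD metadata is attached to the forward context; otherwise run all layers locally.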
1042+ forward_ctx = get_forward_context ()
1043+ afd_metadata = (forward_ctx .afd_metadata
1044+ if forward_ctx is not None else None )
1045+ if afd_metadata is not None :
1046+ hidden_states , residual = self .forward_m2n (hidden_states , residual , positions , afd_metadata )
1047+ else :
1048+ for layer in islice (self .layers , self .start_layer , self .end_layer ):
1049+ hidden_states , residual = layer (positions , hidden_states , residual )
10601050
1061- for layer in islice (self .layers , self .start_layer , self .end_layer ):
1062- hidden_states , residual = layer (positions , hidden_states , residual )
10631051
10641052 if not get_pp_group ().is_last_rank :
10651053 return IntermediateTensors ({