
Commit 3552481

Merge pull request vllm-project#2 from jiangkuaixue123/jcz_afd_v0.11.0rc3_dev
modify deepseekv2model and decodelayer forward
2 parents c9cbf11 + 1a8b6c6 commit 3552481

File tree

2 files changed: +113 -98 lines changed
  vllm/distributed/afd_transfer/afd_connector/metadata.py
  vllm/model_executor/models/deepseek_v2.py

vllm/distributed/afd_transfer/afd_connector/metadata.py

Lines changed: 27 additions & 0 deletions
@@ -9,12 +9,39 @@

import torch

+from abc import ABC, abstractmethod
+
#TODO(yxj):move to AFDExtraFields
from vllm_ascend.ascend_forward_context import MoECommType
from dataclasses import dataclass, field
from typing import Dict


+class AFDRecvHandle(ABC):
+    """
+    Abstract base class for AFD receive handles.
+
+    This provides a handle interface for managing asynchronous AFD operations,
+    allowing waiting for completion of data transfer operations.
+    """
+    @abstractmethod
+    def __init__(self, handle: Any):
+        """Initialize the AFD receive handle.
+
+        Args:
+            handle: Backend-specific handle object
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def wait(self):
+        """Wait for the operation associated with this handle to complete.
+
+        Blocks until the data transfer or computation is finished.
+        """
+        raise NotImplementedError
+
+
class FFNNeedForwardData:
    def __init__(self,
                 moe_comm_type:Optional[MoECommType] = None,
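Note: for context, a minimal sketch (not part of this commit) of how a connector backend might satisfy the new AFDRecvHandle interface, assuming it wraps the async work object returned by torch.distributed; the class name DistWorkRecvHandle and the use of dist.irecv are illustrative assumptions only.

import torch
import torch.distributed as dist

class DistWorkRecvHandle(AFDRecvHandle):  # hypothetical subclass, not in this diff
    def __init__(self, handle):
        # handle is assumed to be the work/request object returned by an async
        # op such as dist.irecv(tensor, src=...)
        self._work = handle

    def wait(self):
        # Block until the underlying receive has completed.
        self._work.wait()

# Illustrative usage (assumed):
#   buf = torch.empty(num_tokens, hidden_size)
#   work = dist.irecv(buf, src=ffn_rank)
#   recv_handle = DistWorkRecvHandle(work)
#   ... overlap other compute ...
#   recv_handle.wait()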

vllm/model_executor/models/deepseek_v2.py

Lines changed: 86 additions & 98 deletions
@@ -790,81 +790,6 @@ def forward(
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
-        if self.role is not None and self.layer_idx >= self.first_k_dense_replace:
-            # --------- ffn need data
-            moe_comm_type = forward_ctx.moe_comm_type
-            num_tokens = hidden_states.shape[0]
-            with_prefill = forward_ctx.with_prefill
-            num_actual_tokens = None
-            ffn_need_forward_data = FFNNeedForwardData(moe_comm_type,num_tokens,with_prefill,num_actual_tokens)
-
-            # if self.is_sequence_parallel:
-            #     hidden_states = sequence_parallel_chunk(hidden_states)
-
-            # router_logits: (num_tokens, n_experts)
-            #
-            router_logits, _ = self.gate(hidden_states)
-            # Add the gating call here; the hidden states will be transmitted along with it later. Use the function from fused_moe directly instead of extracting it.
-            topk_weights, topk_ids, row_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=8,
-                use_grouped_topk=False,
-                renormalize=True,
-            )
-
-            topk_weights = topk_weights.to(torch.float)
-            # print(f'topk_weights shape is {topk_weights.shape},dtype is {topk_weights.dtype}')
-            # print(f'topk_ids shape is {topk_ids.shape},dtype is {topk_ids.dtype}')
-            # print(f'row_idx shape is {row_idx.shape},dtype is {row_idx.dtype}')
-
-            if self.connector_name == "m2nconnector":
-                from vllm_ascend.distributed.M2NAFDConnector import M2NAFDConnectorMetadata
-                m2n_afdconnector_data = M2NAFDConnectorMetadata()
-                m2n_afdconnector_data.moe_expert_num = 64
-                m2n_afdconnector_data.quant_mode = 0
-                m2n_afdconnector_data.aiv_num = 48
-                m2n_afdconnector_data.scale = None
-
-            if self.connector_name == "camconnector":
-                cam_afdconnector_data = CAMAFDConnectorMetadata(
-                    moe_expert_num = 64,
-                    shared_expert_num = 0,
-                    scale = None,
-                    handle = None,
-                    quant_mode = 0,
-                    aiv_num = 48,
-                    batch_size = hidden_states.shape[0],
-                    h = 2048,
-                    k = 6
-                )
-            # TODO(yxj): transmit once per decoded token
-            metadata = AFDConnectorMetadata.create_attention_metadata(
-                layer_idx=self.layer_idx,
-                stage_idx=0,
-                seq_len=hidden_states.shape[0],
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                ffn_need_forward_data=ffn_need_forward_data,
-                m2n_afdconnector_data=m2n_afdconnector_data if self.connector_name == "m2nconnector" else None,
-                cam_afdconnector_data=cam_afdconnector_data if self.connector_name == "camconnector" else None,
-            )
-
-            if self.connector_name == "m2nconnector":
-                handle = afd_connector.send_attn_output(hidden_states,topk_weights,topk_ids,metadata)
-                # print(f'send_attn_output success ,layer id is {self.layer_idx}')
-                metadata.m2n_afdconnector_data.handle = handle
-                hidden_states = afd_connector.recv_ffn_output(hidden_states,metadata)
-                # print(f'recv_ffn_output success ,layer id is {self.layer_idx}')
-            elif self.connector_name == "camconnector":
-                afd_connector.send_attn_output(hidden_states, topk_weights, topk_ids, metadata)
-                hidden_states = afd_connector.recv_ffn_output(metadata)
-            else:
-                afd_connector.send_attn_output(hidden_states,router_logits,topk_weights, topk_ids, row_idx, metadata)
-                hidden_states, _ = afd_connector.recv_ffn_output()
-            if self.role == "attention":
-                return hidden_states, residual
-
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp,
@@ -877,7 +802,7 @@ def forward(
            hidden_states *= 1. / self.routed_scaling_factor

        return hidden_states, residual
-
+
    def compute_attn_output(
        self,
        positions: torch.Tensor,
@@ -908,8 +833,24 @@ def compute_attn_output(
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
+
+        topk_weights = None
+        topk_ids = None
+        row_idx = None
+        # Compute gate on attention side.
+        if self.layer_idx >= self.first_k_dense_replace and self.afd_config.compute_gate_on_attention:
+            router_logits, _ = self.gate(hidden_states)
+            topk_weights, topk_ids, row_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=8,
+                use_grouped_topk=False,
+                renormalize=True,
+            )
+
+            topk_weights = topk_weights.to(torch.float)

-        return hidden_states, residual
+        return hidden_states, residual, topk_weights, topk_ids, row_idx

    def compute_ffn_output(self,
                           hidden_states: torch.Tensor,
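With this hunk, compute_attn_output returns the routing results alongside the hidden states; when the gate is not computed on the attention side, the three routing values stay None. A caller sketch mirroring the unpacking used in forward_m2n further down; the None-handling branch is an assumption, not part of the diff.

hidden_states, residual, topk_weights, topk_ids, row_idx = \
    layer.compute_attn_output(positions, hidden_states, residual)

if topk_weights is None:
    # compute_gate_on_attention is disabled or this is a dense layer;
    # the FFN side would then have to run the gate itself (assumption).
    pass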
@@ -954,6 +895,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
+        self.first_k_dense_replace = config.first_k_dense_replace
+        self.afd_config = vllm_config.afd_config
+        self.connector_name = self.afd_config.afd_connector if self.afd_config is not None else None

        self.vocab_size = config.vocab_size

@@ -993,15 +937,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

-    def forward_with_afd(
+    def forward_m2n(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        afd_metadata: AFDMetadata
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    )-> tuple[torch.Tensor, torch.Tensor]:
        recv_handle = None
+
+        forward_ctx = get_forward_context()
+        moe_comm_type = forward_ctx.moe_comm_type
+        num_tokens = hidden_states.shape[0]
+        with_prefill = forward_ctx.with_prefill
+        num_actual_tokens = None
+        ffn_need_forward_data = FFNNeedForwardData(moe_comm_type,num_tokens,with_prefill,num_actual_tokens)
+
        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            # Compute dense layers on attn side.
+            if layer.layer_idx < self.first_k_dense_replace:
+                hidden_states, residual = layer(positions, hidden_states, residual)
+                continue
+
            logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} metadata:{afd_metadata} hidden_states:{hidden_states.shape}")
            afd_connector = afd_metadata.afd_connector
            afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
@@ -1010,21 +967,55 @@ def forward_with_afd(
            logger.info(f"jcz deepseekv2 layer_idx:{layer.layer_idx} start_loc:{afd_metadata.afd_tokens_start_loc} "
                        f"start_idx:{start_idx} end_idx:{end_idx} "
                        f"stage_idx:{afd_metadata.afd_stage_idx}")
+
            if recv_handle is not None:
                for work in recv_handle:
                    work.wait()
-            current_hidden, residual = layer(positions, hidden_states, residual)
+
+            current_hidden, residual, topk_weights, topk_ids, row_idx = \
+                layer.compute_attn_output(positions, hidden_states, residual)
+            if self.connector_name == "m2nconnector":
+                from vllm_ascend.distributed.M2NAFDConnector import M2NAFDConnectorMetadata
+                m2n_afdconnector_data = M2NAFDConnectorMetadata()
+                m2n_afdconnector_data.moe_expert_num = 64
+                m2n_afdconnector_data.quant_mode = 0
+                m2n_afdconnector_data.aiv_num = 48
+                m2n_afdconnector_data.scale = None
+            if self.connector_name == "camconnector":
+                cam_afdconnector_data = CAMAFDConnectorMetadata(
+                    moe_expert_num = 64,
+                    shared_expert_num = 0,
+                    scale = None,
+                    handle = None,
+                    quant_mode = 0,
+                    aiv_num = 48,
+                    batch_size = hidden_states.shape[0],
+                    h = 2048,
+                    k = 6
+                )
+
            metadata = AFDConnectorMetadata.create_attention_metadata(
                layer_idx=layer.layer_idx,
                stage_idx=afd_metadata.afd_stage_idx,
-                seq_len=current_hidden.shape[0],
-                dtype=current_hidden.dtype,
-                device=current_hidden.device,
+                seq_len=hidden_states.shape[0],
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                ffn_need_forward_data=ffn_need_forward_data,
+                m2n_afdconnector_data=m2n_afdconnector_data if self.connector_name == "m2nconnector" else None,
+                cam_afdconnector_data=cam_afdconnector_data if self.connector_name == "camconnector" else None,
            )
-            afd_connector.send_attn_output(current_hidden, metadata)
-            hidden_states, recv_metadata = afd_connector.recv_ffn_output()
-            if recv_metadata.recv_handle_list is not None:
-                recv_handle = recv_metadata.recv_handle_list
+
+            if self.connector_name == "m2nconnector":
+                handle = afd_connector.send_attn_output(current_hidden,topk_weights,topk_ids,metadata)
+                metadata.m2n_afdconnector_data.handle = handle
+                hidden_states, recv_handle = afd_connector.recv_ffn_output(hidden_states,metadata)
+            elif self.connector_name == "camconnector":
+                afd_connector.send_attn_output(current_hidden, topk_weights, topk_ids, metadata)
+                hidden_states = afd_connector.recv_ffn_output(metadata)
+            else:
+                afd_connector.send_attn_output(current_hidden, router_logits, topk_weights, topk_ids, row_idx, metadata)
+                hidden_states, _ = afd_connector.recv_ffn_output()
+
            if dbo_enabled():
                dbo_yield()
        return hidden_states, residual
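The hunk above keeps the overlap pattern of the old forward_with_afd: on the m2nconnector path, recv_ffn_output returns handles instead of blocking, and they are only waited on at the top of the next layer iteration. A condensed, illustrative sketch of that pattern (schematic only, names follow the diff):

recv_handle = None
for layer in moe_layers:                     # illustrative loop, not the literal code
    if recv_handle is not None:
        for work in recv_handle:             # handles from the previous layer's receive
            work.wait()                      # previous layer's FFN output is now ready
    attn_out, residual, w, ids, row = layer.compute_attn_output(positions, hidden_states, residual)
    handle = afd_connector.send_attn_output(attn_out, w, ids, metadata)
    metadata.m2n_afdconnector_data.handle = handle
    hidden_states, recv_handle = afd_connector.recv_ffn_output(hidden_states, metadata)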
@@ -1048,18 +1039,15 @@ def forward(
            residual = intermediate_tensors["residual"]

        # TODO(jcz): later need fix this
-        # forward_ctx = get_forward_context()
-        # afd_metadata = (forward_ctx.afd_metadata
-        #                 if forward_ctx is not None else None)
-        # if afd_metadata != None:
-        #     hidden_states, residual = self.forward_with_afd(hidden_states, residual,
-        #                                                     positions, afd_metadata)
-        # else:
-        #     for layer in islice(self.layers, self.start_layer, self.end_layer):
-        #         hidden_states, residual = layer(positions, hidden_states, residual)
+        forward_ctx = get_forward_context()
+        afd_metadata = (forward_ctx.afd_metadata
+                        if forward_ctx is not None else None)
+        if afd_metadata != None:
+            hidden_states, residual = self.forward_m2n(hidden_states, residual, positions, afd_metadata)
+        else:
+            for layer in islice(self.layers, self.start_layer, self.end_layer):
+                hidden_states, residual = layer(positions, hidden_states, residual)

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({