Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions fastdeploy/model_executor/models/ernie4_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import re
from functools import partial
from typing import Dict, Union

Expand Down Expand Up @@ -250,7 +251,7 @@ def __init__(
self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
self.norm = fd_config.speculative_config.sharing_model.ernie.norm

self.layers = nn.LayerList(
self.mtp_block = nn.LayerList(
[
Ernie4_5_DecoderLayer(
fd_config=fd_config,
Expand Down Expand Up @@ -296,7 +297,7 @@ def load_state_dict(self, state_dict):
self.eh_proj.load_state_dict(state_dict)
for i in range(self.num_layers):
logger.info(f"Start load layer {i}")
self.layers[i].load_state_dict(state_dict)
self.mtp_block[i].load_state_dict(state_dict)

def forward(
self,
Expand All @@ -315,7 +316,7 @@ def forward(
hidden_states = self.eh_proj(inputs_embedding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)

hidden_states = hidden_states + residual

Expand Down Expand Up @@ -374,17 +375,23 @@ def load_weights(self, weights_iterator) -> None:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""

from fastdeploy.model_executor.utils import default_weight_loader
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)

all_param_mapping = [
# (param_name, weight_name, expert_id, shard_id)
("embed_tokens.embeddings", "embed_tokens", None, None),
("lm_head.linear", "lm_head", None, None),
("enorm", "mtp_emb_norm.0", None, None),
("hnorm", "mtp_hidden_norm.0", None, None),
("eh_proj.linear", "mtp_linear_proj.0", None, None),
]

params_dict = dict(self.named_parameters())
shard_id = None

process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, exp_id, shard_id in all_param_mapping:
if weight_name not in loaded_weight_name:
Expand All @@ -396,11 +403,16 @@ def load_weights(self, weights_iterator) -> None:
else:
if loaded_weight_name not in params_dict.keys():
continue
model_param_name = loaded_weight_name
param = params_dict[loaded_weight_name]

# Get weight loader from parameter and set weight
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)

def compute_logits(self, hidden_states: paddle.Tensor):
"""
Expand Down
Loading