14 changes: 12 additions & 2 deletions paddleformers/transformers/model_utils.py
@@ -2978,7 +2978,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         with ContextManagers(init_contexts):
             model = cls(config, *init_args, **model_kwargs)
 
-        if hasattr(cls, "_gen_aoa_config") and load_checkpoint_format == "flex_checkpoint":
+        if load_checkpoint_format == "flex_checkpoint":
+            if not hasattr(cls, "_gen_aoa_config"):
+                raise RuntimeError(
+                    "When using flex_checkpoint to load Hugging Face open-source weights, "
+                    "the model must implement the _gen_aoa_config function to provide checkpoint conversion rules."
+                )
             aoa_config = cls._gen_aoa_config(config)
             sharded_state_dict = model.sharded_state_dict()
             dist.load_state_dict(
@@ -3216,7 +3221,12 @@ def save_pretrained(
         # Only save the model in distributed training setup
         model_to_save = unwrap_model(self)
 
-        if hasattr(self.__class__, "_gen_inv_aoa_config") and save_checkpoint_format == "flex_checkpoint":
+        if save_checkpoint_format == "flex_checkpoint":
+            if not hasattr(self.__class__, "_gen_inv_aoa_config"):
+                raise RuntimeError(
+                    "When using flex_checkpoint to save Hugging Face weights, "
+                    "the model must implement the _gen_inv_aoa_config function to provide checkpoint conversion rules."
+                )
             aoa_config = self.__class__._gen_inv_aoa_config(model_to_save.config)
 
             clean_unrelated_safetensors(save_dir)
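The two hunks above turn a silent skip into a hard failure: flex_checkpoint loading and saving now raise immediately when a model class does not ship AoA conversion rules. For reference, a minimal sketch of the required hook (hypothetical, not part of this diff; the method name and the "aoa_statements" key are taken from the qwen3_moe implementation below, and the statement itself is illustrative only):

    # Hedged sketch: the minimal hook a model must provide to opt into
    # flex_checkpoint loading. Real implementations list one AoA statement
    # per weight family, as in qwen3_moe/modeling.py below.
    @classmethod
    def _gen_aoa_config(cls, config):
        model_prefix = "" if cls == cls.base_model_class else "model."
        return {
            "aoa_statements": [
                # left-hand side: tensor name in the HF checkpoint;
                # right-hand side: in-framework parameter name
                f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
            ]
        }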
25 changes: 20 additions & 5 deletions paddleformers/transformers/qwen3_moe/modeling.py
@@ -633,13 +633,15 @@ def _gen_aoa_config(cls, config: Qwen3MoeConfig):
         model_prefix = "" if cls == cls.base_model_class else "model."
         aoa_config = {
             "aoa_statements": [
-                f"model.layers.$LAYER_ID.mlp.gate.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.gate.weight, dtype='float32'",
+                f"model.layers.$LAYER_ID.mlp.gate.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.gate.weight",
                 f"model.layers.$LAYER_ID.self_attn.o_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight",
                 f"model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.down_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.experts.$EXPERT_ID.down_proj.weight",
                 f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
                 f"model.layers.$LAYER_ID.input_layernorm.weight -> {model_prefix}layers.$LAYER_ID.input_layernorm.weight",
                 f"model.layers.$LAYER_ID.post_attention_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight",
                 f"model.norm.weight -> {model_prefix}norm.weight",
+                f"model.layers.$LAYER_ID.self_attn.k_norm.weight -> {model_prefix}layers.$LAYER_ID.self_attn.k_norm.weight",
+                f"model.layers.$LAYER_ID.self_attn.q_norm.weight -> {model_prefix}layers.$LAYER_ID.self_attn.q_norm.weight",
             ]
         }
 
@@ -679,13 +681,17 @@ def _gen_aoa_config(cls, config: Qwen3MoeConfig):
     def _gen_inv_aoa_config(cls, config: Qwen3MoeConfig):
         model_prefix = "" if cls == cls.base_model_class else "model."
         aoa_statements = [
-            f"{model_prefix}layers.$LAYER_ID.mlp.gate.weight^T -> model.layers.$LAYER_ID.mlp.gate.weight, dtype='bfloat16'",
+            # do cast
+            f"{model_prefix}layers.$LAYER_ID.mlp.gate.weight^T -> model.layers.$LAYER_ID.mlp.gate.weight",
+            # do transpose
             f"{model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight^T -> model.layers.$LAYER_ID.self_attn.o_proj.weight",
             f"{model_prefix}layers.$LAYER_ID.mlp.experts.$EXPERT_ID.down_proj.weight^T -> model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.down_proj.weight",
             f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
             f"{model_prefix}layers.$LAYER_ID.input_layernorm.weight -> model.layers.$LAYER_ID.input_layernorm.weight",
             f"{model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight -> model.layers.$LAYER_ID.post_attention_layernorm.weight",
             f"{model_prefix}norm.weight -> model.norm.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.k_norm.weight -> model.layers.$LAYER_ID.self_attn.k_norm.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.q_norm.weight -> model.layers.$LAYER_ID.self_attn.q_norm.weight",
         ]
 
         if not config.fuse_attention_qkv:
@@ -703,9 +709,14 @@ def _gen_inv_aoa_config(cls, config: Qwen3MoeConfig):
             ]
 
             aoa_statements += [
-                f"model.layers.$LAYER_ID.self_attn.{x}_proj.weight^T -> model.layers.$LAYER_ID.self_attn.{x}_proj.weight"
+                f"model.layers.{layer_idx}.self_attn.{x}_proj.weight^T -> model.layers.{layer_idx}.self_attn.{x}_proj.weight"
+                for layer_idx in range(config.num_hidden_layers)
                 for x in ("q", "k", "v")
             ]
+            if config.attention_bias:
+                aoa_statements += [
+                    f"{model_prefix}layers.$LAYER_ID.self_attn.qkv_proj.bias -> model.layers.$LAYER_ID.self_attn.q_proj.bias, model.layers.$LAYER_ID.self_attn.k_proj.bias, model.layers.$LAYER_ID.self_attn.v_proj.bias , fused_qkv, num_heads={config.num_attention_heads}, num_key_value_groups = {config.num_key_value_heads}, axis=0",
+                ]
 
         if not config.fuse_attention_ffn:
             aoa_statements += [
@@ -715,8 +726,12 @@ def _gen_inv_aoa_config(cls, config: Qwen3MoeConfig):
         else:
             aoa_statements += [
                 f"{model_prefix}layers.$LAYER_ID.mlp.experts.$EXPERT_ID.up_gate_proj.weight -> model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.gate_proj.weight, model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.up_proj.weight, fused_ffn",
-                "model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.gate_proj.weight^T -> model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.gate_proj.weight",
-                "model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.up_proj.weight^T -> model.layers.$LAYER_ID.mlp.experts.$EXPERT_ID.up_proj.weight",
             ]
+            aoa_statements += [
+                f"model.layers.{layer_idx}.mlp.experts.{e}.{y}_proj.weight^T -> model.layers.{layer_idx}.mlp.experts.{e}.{y}_proj.weight"
+                for layer_idx in range(config.num_hidden_layers)
+                for e in range(config.num_experts)
+                for y in ("gate", "up")
+            ]
 
         if config.tie_word_embeddings:
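Reading the statements above: `^T` transposes the source tensor, `$LAYER_ID` and `$EXPERT_ID` are wildcards expanded per layer or expert, a trailing `dtype=...` casts, and a multi-target statement with `fused_qkv` or `fused_ffn` splits one fused tensor into several. As a hedged sketch of what the `fused_qkv` bias statement asks the engine to do (the actual splitting lives in the AoA engine; the `[q | k | v]` layout along axis 0 is inferred from `num_heads`, `num_key_value_groups`, and `axis=0`, not confirmed by this diff):

    import paddle

    # Hypothetical illustration of the fused_qkv split, assuming the fused
    # bias is laid out as [q | k | v] along axis 0.
    def split_fused_qkv_bias(qkv_bias, num_attention_heads, num_key_value_heads, head_dim):
        q_size = num_attention_heads * head_dim
        kv_size = num_key_value_heads * head_dim
        q, k, v = paddle.split(qkv_bias, [q_size, kv_size, kv_size], axis=0)
        return q, k, v

The switch from `$LAYER_ID` wildcards to explicit `range(config.num_hidden_layers)` and `range(config.num_experts)` comprehensions in the fused branches suggests the per-tensor transpose has to be enumerated once the fused tensor has been split; that reading is inferred from the diff, not stated in it.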
49 changes: 44 additions & 5 deletions tests/transformers/glm4_moe/test_modeling.py
@@ -21,7 +21,12 @@
 import paddle
 from parameterized import parameterized
 
-from paddleformers.transformers import Glm4MoeConfig, Glm4MoeForCausalLM, Glm4MoeModel
+from paddleformers.transformers import (
+    AutoConfig,
+    Glm4MoeConfig,
+    Glm4MoeForCausalLM,
+    Glm4MoeModel,
+)
 from tests.testing_utils import require_package
 from tests.transformers.test_configuration_common import ConfigTester
 from tests.transformers.test_generation_utils import GenerationTesterMixin
@@ -372,14 +377,49 @@ def test_model_name_list(self):
         pass
 
     def test_save_load(self):
+        model_path = "PaddleFormers/tiny-random-glm4moe"
         for model_class in self.all_model_classes:
+            # test from_pretrained
+            model1 = model_class.from_pretrained(model_path, download_hub="aistudio", convert_from_hf=True)
+
+            model2 = model_class.from_pretrained(
+                model_path, download_hub="aistudio", load_checkpoint_format="flex_checkpoint"
+            )
+
+            model_state_1 = model1.state_dict()
+            model_state_2 = model2.state_dict()
+
+            for k, v in model_state_1.items():
+                md51 = v._md5sum()
+                md52 = model_state_2[k]._md5sum()
+                assert md51 == md52
+
+            # test save_pretrained
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model2.save_pretrained(tmpdirname, save_checkpoint_format="flex_checkpoint")
+                model3 = model_class.from_pretrained(tmpdirname, convert_from_hf=True)
+                model_state_3 = model3.state_dict()
+
+                for k, v in model_state_3.items():
+                    md53 = v._md5sum()
+                    md52 = model_state_2[k]._md5sum()
+                    if k.endswith(".mlp.gate.weight"):
+                        md52 = model_state_2[k].cast("bfloat16")._md5sum()
+                        md53 = model_state_3[k].cast("bfloat16")._md5sum()
+                    assert md52 == md53
+        # test fused_qkv and fused_ffn
+        for model_class in self.all_model_classes:
+            model_config = AutoConfig.from_pretrained(model_path)
+
+            model_config.fuse_attention_qkv = True
+            model_config.fuse_attention_ffn = True
+
             model1 = model_class.from_pretrained(
-                "PaddleFormers/tiny-random-glm4moe", download_hub="aistudio", convert_from_hf=True
+                model_path, config=model_config, download_hub="aistudio", convert_from_hf=True
             )
 
             model2 = model_class.from_pretrained(
-                "PaddleFormers/tiny-random-glm4moe", download_hub="aistudio", load_checkpoint_format="flex_checkpoint"
+                model_path, config=model_config, download_hub="aistudio", load_checkpoint_format="flex_checkpoint"
             )
 
             model_state_1 = model1.state_dict()
@@ -393,7 +433,7 @@ def test_save_load(self):
             # test save_pretrained
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model2.save_pretrained(tmpdirname, save_checkpoint_format="flex_checkpoint")
-                model3 = model_class.from_pretrained(tmpdirname, convert_from_hf=True)
+                model3 = model_class.from_pretrained(tmpdirname, config=model_config, convert_from_hf=True)
                 model_state_3 = model3.state_dict()
 
for k, v in model_state_3.items():
Expand All @@ -402,7 +442,6 @@ def test_save_load(self):
if k.endswith(".mlp.gate.weight"):
md52 = model_state_2[k].cast("bfloat16")._md5sum()
md53 = model_state_3[k].cast("bfloat16")._md5sum()
print(k)
assert md52 == md53

def test_hidden_states_output(self):
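The round trip is checked by hashing every parameter; only `.mlp.gate.weight` is compared after casting both sides to bfloat16, presumably because the router gate is kept in float32 in-framework (the removed `dtype` clauses above point the same way) while the exported HF checkpoint stores it in bfloat16. A hypothetical helper capturing the pattern (not part of this PR):

    # Hedged sketch of the comparison loop used in these tests.
    def assert_state_dicts_match(reference, candidate, bf16_suffix=".mlp.gate.weight"):
        for k, v in reference.items():
            a, b = v, candidate[k]
            if k.endswith(bf16_suffix):
                # hash both sides in a common dtype for the router gate
                a, b = a.cast("bfloat16"), b.cast("bfloat16")
            assert a._md5sum() == b._md5sum(), k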
48 changes: 32 additions & 16 deletions tests/transformers/qwen3moe/test_modeling.py
@@ -321,24 +321,40 @@ def test_model_causal_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
 
-    # def test_save_load(self):
-    #     for model_class in self.all_model_classes:
-    #         with tempfile.TemporaryDirectory() as tmpdirname:
-    #             config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-    #             model = model_class(config)
-    #             model.save_pretrained(tmpdirname, save_checkpoint_format="flex_checkpoint")
 
-    #             model1 = model_class.from_pretrained(tmpdirname, convert_from_hf=True)
 
-    #             model2 = model_class.from_pretrained(tmpdirname, load_checkpoint_format="flex_checkpoint")
+    def test_save_load(self):
[Review comment from a Collaborator on this line: "Write this case directly into the parent class; all models need FC (flex_checkpoint) coverage."]
+        for model_class in self.all_model_classes:
+            # test from_pretrained
+            model1 = model_class.from_pretrained(
+                "PaddleFormers/tiny-random-qwen3moev2",
+                download_hub="aistudio",
+                convert_from_hf=True,
+            )
 
-    #         model_state_1 = model1.state_dict()
-    #         model_state_2 = model2.state_dict()
+            model2 = model_class.from_pretrained(
+                "PaddleFormers/tiny-random-qwen3moev2",
+                download_hub="aistudio",
+                load_checkpoint_format="flex_checkpoint",
+            )
 
-    #         for k, v in model_state_1.items():
-    #             md51 = v._md5sum()
-    #             md52 = model_state_2[k]._md5sum()
-    #             assert md51 == md52
+            model_state_1 = model1.state_dict()
+            model_state_2 = model2.state_dict()
 
+            for k, v in model_state_1.items():
+                md51 = v._md5sum()
+                md52 = model_state_2[k]._md5sum()
+                assert md51 == md52

+            # test save_pretrained
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model2.save_pretrained(tmpdirname, save_checkpoint_format="flex_checkpoint")
+                model3 = model_class.from_pretrained(tmpdirname, convert_from_hf=True)
+                model_state_3 = model3.state_dict()
 
+                for k, v in model_state_3.items():
+                    md53 = v._md5sum()
+                    md52 = model_state_2[k]._md5sum()
+                    md53 = model_state_3[k]._md5sum()
+                    assert md52 == md53


class Qwen3MoeIntegrationTest(unittest.TestCase):
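Per the review comment above, this round trip could live once in the shared tester base so every model gets flex_checkpoint coverage for free. A hypothetical sketch (the helper name and its placement in the common mixin are assumptions, not part of this PR):

    # Hedged sketch of a shared helper for the common test mixin.
    def check_flex_checkpoint_roundtrip(self, model_path, **kwargs):
        for model_class in self.all_model_classes:
            converted = model_class.from_pretrained(model_path, convert_from_hf=True, **kwargs)
            flex = model_class.from_pretrained(
                model_path, load_checkpoint_format="flex_checkpoint", **kwargs
            )
            flex_state = flex.state_dict()
            for k, v in converted.state_dict().items():
                assert v._md5sum() == flex_state[k]._md5sum(), k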