Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddleformers/cli/hparams/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class ModelArguments:
default=False,
metadata={"help": "GPT3 model, use fast layernorm"},
)
fuse_attention_qkv: bool = field(
default=None,
metadata={"help": "whether to fuse attention qkv"},
)
fuse_attention_ffn: bool = field(
default=None,
metadata={"help": "whether to fuse first up and gate proj in mlp block"},
)
attn_impl: str = field(default="flashmask", metadata={"help": "Attention implementation"})
fuse_gate_detach_matmul: bool = field(
default=True,
Expand Down
2 changes: 1 addition & 1 deletion paddleformers/cli/hparams/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -
Returns:
_TRAIN_CLS: _description_
"""
parser = PdArgumentParser(_TRAIN_ARGS)
parser = PdArgumentParser(_TRAIN_ARGS, conflict_handler="resolve")
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)

Expand Down
52 changes: 38 additions & 14 deletions paddleformers/cli/train/auto_parallel/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@
from paddleformers.trainer.trainer import Trainer
from paddleformers.trainer.trainer_utils import set_seed
from paddleformers.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
CosineAnnealingWithWarmupDecay,
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLM,
)
from paddleformers.transformers.configuration_utils import LlmMetaConfig
from paddleformers.utils.log import logger
Expand Down Expand Up @@ -202,15 +203,8 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)

# TODO: only support llama model now
config_class = LlamaConfig
model_class = LlamaForCausalLM

config = config_class.from_pretrained(model_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
# config = AutoConfig.from_pretrained(model_args.model_name_or_path)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
LlmMetaConfig.set_llm_config(config, training_args)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

Expand Down Expand Up @@ -276,6 +270,13 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
if training_args.no_recompute_layers is not None:
training_args.no_recompute_layers.sort()

if training_args.use_intermediate_api:
config.use_single_model_implementation = True
config.tensor_parallel_degree = 1
config.sharding_parallel_degree = 1
config.sep_parallel_degree = 1
config.context_parallel_degree = 1

print("Final pre-training config:", config)

# Set the dtype for loading model
Expand All @@ -286,9 +287,33 @@ def run_auto_parallel(model_args, data_args, generating_args, training_args):
if training_args.bf16:
dtype = "bfloat16"

with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
criterion = model.criterion
model_class = AutoModelForCausalLM

if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1:
model_class = AutoModelForCausalLMPipe

architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
if (
any(architecture in str(config.architectures) for architecture in architectures_to_check)
and training_args.data_parallel_degree > 1
):
training_args.use_expert_parallel = True

if model_args.continue_training:
if training_args.autotuner_benchmark:
model = model_class.from_config(config, dtype=dtype)
else:
model = model_class.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=dtype,
)
else:
if training_args.enable_auto_parallel:
with paddle.LazyGuard():
model = model_class.from_config(config, dtype=dtype)
else:
model = model_class.from_config(config, dtype=dtype)

if training_args.recompute:

Expand Down Expand Up @@ -344,7 +369,6 @@ def fn(layer):

trainer = PretrainingTrainer(
model=model,
criterion=criterion,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset if training_args.do_train else None,
Expand Down
1 change: 0 additions & 1 deletion paddleformers/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
f"removing line of `from __future__ import annotations` which opts in Postponed "
f"Evaluation of Annotations (PEP 563)"
)

for field in dataclasses.fields(dtype):
if not field.init:
continue
Expand Down
10 changes: 10 additions & 0 deletions paddleformers/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ class PretrainedConfig:
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has a output word embedding layer.

use_single_model_implementation (`bool`, *optional*, defaults to `False`):
Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.

dtype (`str`, *optional*):
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
Expand Down Expand Up @@ -609,6 +612,13 @@ def __init__(self, **kwargs):
self.use_cache = kwargs.pop("use_cache", False)
self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)

# for run model in single card mode
self.use_single_model_implementation = kwargs.pop("use_single_model_implementation", False)
if self.use_single_model_implementation:
self.tensor_parallel_degree = 1
self.sep_parallel_degree = 1
self.context_parallel_degree = 1

# for transformers fuse
self.fuse_linear = kwargs.pop("fuse_linear", False)
self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
Expand Down
40 changes: 40 additions & 0 deletions paddleformers/transformers/llama/auto_dist_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.distributed as dist


def get_dist_config(model, prefix=""):
"""Generate distributed configuration for Llama model"""
if prefix != "":
assert prefix.endswith(".")

config = {
"mp_config": {
"parallelize_plan": {
f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
}
},
}
return config
50 changes: 47 additions & 3 deletions paddleformers/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

from paddleformers.transformers.conversion_utils import (
StateDictNameMapping,
init_name_mappings,
)

from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
from ...nn.criterion.interface import CriterionLayer
from ...nn.embedding import Embedding as GeneralEmbedding
Expand All @@ -34,6 +39,7 @@
from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ..model_utils import PretrainedModel, register_base_model
from ..modeling_rope_utils import dynamic_rope_update
from .auto_dist_config import get_dist_config
from .configuration import LlamaConfig


Expand Down Expand Up @@ -160,9 +166,9 @@ def forward(
q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)

query_states = self.q_proj(hidden_states).view(q_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(kv_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(kv_shape).transpose(1, 2)
query_states = self.q_proj(hidden_states).reshape(q_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).reshape(kv_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).reshape(kv_shape).transpose(1, 2)

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
Expand Down Expand Up @@ -336,6 +342,40 @@ class LlamaPretrainedModel(PretrainedModel):
"down_proj",
]

@classmethod
def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
mappings: list[StateDictNameMapping] = []
model_mappings = [
["embed_tokens.weight"],
["norm.weight"],
]
for layer_index in range(config.num_hidden_layers):
layer_mappings = [
[f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
[f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
[f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
[f"layers.{layer_index}.input_layernorm.weight"],
[f"layers.{layer_index}.post_attention_layernorm.weight"],
]
model_mappings.extend(layer_mappings)

init_name_mappings(mappings=model_mappings)
# base-model prefix "LlamaModel"
if "LlamaModel" not in config.architectures:
for mapping in model_mappings:
mapping[0] = "model." + mapping[0]
mapping[1] = "llama." + mapping[1]
if not config.tie_word_embeddings:
model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])

mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings

@classmethod
def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True):
from ..conversion_utils import split_or_merge_func
Expand Down Expand Up @@ -701,6 +741,10 @@ def forward(
attentions=outputs.attentions,
)

def auto_dist_config(self, prefix=""):
assert self.config.use_single_model_implementation, "Use `get_dist_config` only in single card mode."
return get_dist_config(self, prefix)


class LlamaForCausalLMPipe(GeneralModelForCausalLMPipe):
config_class = LlamaConfig
Expand Down
Loading