Skip to content

Support passing a Python file as config. #10489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions llm/config/llama/pretrain_argument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example pre-training configuration expressed as a plain Python file.
#
# Settings are grouped into dicts purely for readability. When this file is
# loaded through `read_python` in paddlenlp/trainer/argparser.py, every dict
# is flattened and only its keys become `--key value` arguments; the dict
# names themselves are discarded.

# models
model_name_or_path = "meta-llama/Meta-Llama-3-8B"
tokenizer_name_or_path = "meta-llama/Meta-Llama-3-8B"

# data and checkpoint locations
checkpoint_dirs = {
    "input_dir": "./data",
    "output_dir": "./checkpoints/pretrain_ckpts",
    "unified_checkpoint": True,
    "save_total_limit": 2,
}

# which stages run, plus recompute / dataloader switches
training_control = {
    "do_train": True,
    "do_eval": True,
    "do_predict": True,
    "disable_tqdm": True,
    "recompute": False,
    "distributed_dataloader": 1,
    "recompute_granularity": "full",
}

# batch sizes, parallelism and optimisation schedule
training_args = {
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "per_device_eval_batch_size": 2,
    "tensor_parallel_degree": 2,
    "pipeline_parallel_degree": 1,
    "sharding": "stage2",
    "virtual_pp_degree": 1,
    "sequence_parallel": 0,
    "max_seq_length": 4096,
    "learning_rate": 3e-05,
    "min_learning_rate": 3e-06,
    "warmup_steps": 30,
    "logging_steps": 1,
    "max_steps": 10000,
    "save_steps": 5000,
    "eval_steps": 1000,
    "weight_decay": 0.01,
    "warmup_ratio": 0.01,
    "max_grad_norm": 1.0,
    "dataloader_num_workers": 1,
    "continue_training": 0,
}

# fused kernels and mixed precision
accelerate = {
    "use_flash_attention": True,
    "use_fused_rms_norm": True,
    "use_fused_rope": True,
    "bf16": True,
    "fp16_opt_level": "O2",
}
2 changes: 2 additions & 0 deletions llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def main():
gen_args, model_args, reft_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"):
gen_args, model_args, reft_args, data_args, training_args = parser.parse_yaml_file_and_cmd_lines()
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".py"):
gen_args, model_args, reft_args, data_args, training_args = parser.parse_python_file_and_cmd_lines()
else:
gen_args, model_args, reft_args, data_args, training_args = parser.parse_args_into_dataclasses()

Expand Down
2 changes: 2 additions & 0 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ def main():
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args = parser.parse_yaml_file_and_cmd_lines()
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".py"):
model_args, data_args, training_args = parser.parse_python_file_and_cmd_lines()
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

Expand Down
58 changes: 58 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,64 @@
args = yaml_args + sys.argv[2:]
return self.common_parse(args, return_remaining_strings)

def read_python(self, python_file: str) -> list:
    """Load a Python config file and turn its variables into CLI-style arguments.

    The file is executed and its top-level names are collected. Names that
    start with ``__``, imported modules and callables (functions, classes)
    are ignored so that helper code in the config file does not leak into
    the argument list. Dict values are treated as groups: they are flattened
    recursively, their keys are promoted to top-level settings and the
    dict's own name is dropped. If the same key appears more than once, the
    value assigned last wins.

    Args:
        python_file: Path of the Python configuration file.

    Returns:
        A flat list of strings of the form ``["--key", "value", ...]``.
        A list value contributes one string per element after its
        ``--key``; every other value is converted with ``str``.

    Raises:
        FileNotFoundError: If ``python_file`` does not exist.
    """
    from types import ModuleType

    python_file = Path(python_file)
    if not python_file.exists():
        raise FileNotFoundError(f"The argument file {python_file} does not exist.")

    def flatten(config):
        # Promote the keys of nested dicts; the nesting is only for grouping.
        flat = {}
        for key, value in config.items():
            if isinstance(value, dict):
                flat.update(flatten(value))
            else:
                flat[key] = value
        return flat

    def is_setting(name, value):
        # Keep plain data only: drop dunder names (e.g. `__builtins__`, which
        # exec() inserts), imported modules and functions/classes.
        return not (name.startswith("__") or isinstance(value, ModuleType) or callable(value))

    # SECURITY: the config file is executed as arbitrary Python code. Only
    # load configuration files that come from a trusted source.
    with open(python_file, "r", encoding="utf-8") as f:
        code = compile(f.read(), str(python_file), "exec")
    namespace = {}
    exec(code, namespace)

    data = flatten({name: value for name, value in namespace.items() if is_setting(name, value)})

    # No dict can remain after flattening, so only lists need special care.
    python_args = []
    for key, value in data.items():
        python_args.append(f"--{key}")
        if isinstance(value, list):
            python_args.extend(str(v) for v in value)
        else:
            python_args.append(str(value))
    return python_args

Check warning on line 381 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L381

Added line #L381 was not covered by tests

def parse_python_file_and_cmd_lines(self, return_remaining_strings=False) -> Tuple[DataClass, ...]:
    """
    Extend the functionality of `read_python` to handle command line arguments in addition to loading a Python
    config file.

    The first command line argument (`sys.argv[1]`) must be the path of a `.py` config file. The arguments read
    from that file are placed before the remaining command line arguments (`sys.argv[2:]`), so that when there is
    a conflict between the two, the command line arguments take precedence.

    Args:
        return_remaining_strings: Forwarded unchanged to `self.common_parse`.

    Returns:
        The result of `self.common_parse`; per the annotation, a tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: If no first argument was given or it does not end with ".py".
    """
    if len(sys.argv) < 2 or not sys.argv[1].endswith(".py"):
        # Report a missing argument the same way as a wrong one instead of failing with IndexError.
        first_arg = sys.argv[1] if len(sys.argv) >= 2 else None
        raise ValueError(f"The first argument should be a PYTHON file, but it is {first_arg}")
    python_args = self.read_python(sys.argv[1])
    # In case of conflict, command line arguments take precedence
    args = python_args + sys.argv[2:]
    return self.common_parse(args, return_remaining_strings)

Check warning on line 401 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L400-L401

Added lines #L400 - L401 were not covered by tests

def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
"""
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
Expand Down