Commit
modify data_path from args or test_config and fix some bugs (#50)
---------
Co-authored-by: UP <UP@.wwlwenli@163.com>
upvenly authored Apr 10, 2023
1 parent d6399bd commit a61e300
Showing 14 changed files with 24 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -390,7 +390,7 @@ nvidia_monitor.log rank1.out.log rank4.out.log rank7.out.log

### Contributing code

This project is currently being built jointly by the Beijing Academy of Artificial Intelligence (BAAI), Kunlunxin, Iluvatar CoreX, and Baidu PaddlePaddle
This project is currently being built jointly by the Beijing Academy of Artificial Intelligence (BAAI), Iluvatar CoreX, Baidu PaddlePaddle, and Kunlunxin
Framework and chip teams, as well as individuals, are warmly invited to participate!
### Contact us

8 changes: 4 additions & 4 deletions docs/dev/run_pretraining.py.example
@@ -18,7 +18,7 @@ sys.path.append(os.path.abspath(os.path.join(CURR_PATH,
# Local libraries
import config
from driver import Event, dist_pytorch
from driver.helper import InitHelper, get_finished_info
from driver.helper import InitHelper

# TODO: import the relevant modules, methods, and variables. Keep the names consistent here; the implementation may differ.
from train import trainer_adapter
@@ -139,17 +139,17 @@ def main() -> Tuple[Any, Any]:

if __name__ == "__main__":
start = time.time()
config, state = main()
config_update, state = main()
if not dist_pytorch.is_main_process():
sys.exit(0)

# Write training info to the log
e2e_time = time.time() - now
if config.do_train:
if config_update.do_train:

# TODO: build the statistics needed for training, including but not limited to: e2e_time, training_sequences_per_second,
# converged, final_accuracy, raw_train_time, init_time
training_perf = (dist_pytorch.global_batch_size(config) *
training_perf = (dist_pytorch.global_batch_size(config_update) *
state.global_steps) / state.raw_train_time
finished_info = {
"e2e_time": e2e_time,
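The rename to config_update presumably keeps the module-level assignment from rebinding the imported config module, and training_perf reports samples processed per second of raw training time. A minimal sketch of that statistic with made-up numbers:

```python
# Illustrative only; the values below are hypothetical, not from the benchmark.
global_batch_size = 256      # samples per global step (hypothetical)
global_steps = 1000          # steps completed when training finished (hypothetical)
raw_train_time = 512.0       # seconds of pure training time (hypothetical)

# Same shape as the finished_info computation: samples per second of training.
training_perf = (global_batch_size * global_steps) / raw_train_time
print(f"training_perf: {training_perf:.1f} samples/s")  # 500.0
```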
5 changes: 0 additions & 5 deletions training/benchmarks/driver/base.py
@@ -21,11 +21,6 @@ def __init__(self, config, mutable_params):
self.logger = None

def setup_config(self, parser):
parser.add_argument(
"--data_dir",
type=str,
required=False,
help="The full path to the root of external modules")
parser.add_argument(
"--extern_module_dir",
type=str,
3 changes: 1 addition & 2 deletions training/benchmarks/driver/check.py
@@ -47,5 +47,4 @@ def check_config(config):
if config.gradient_accumulation_steps < 1:
raise ValueError(
"Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
.format(config.gradient_accumulation_steps))
return config
.format(config.gradient_accumulation_steps))
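With the trailing return config dropped, check_config appears to validate purely by side effect: it raises on bad values and returns nothing. A small, self-contained sketch of that calling convention, using SimpleNamespace as a stand-in for the real config object:

```python
from types import SimpleNamespace

def check_config(config):
    # Validate in place: raise on invalid values, return nothing.
    if config.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(config.gradient_accumulation_steps))

check_config(SimpleNamespace(gradient_accumulation_steps=4))    # passes silently
# check_config(SimpleNamespace(gradient_accumulation_steps=0))  # would raise ValueError
```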
4 changes: 2 additions & 2 deletions training/benchmarks/driver/config_manager.py
@@ -144,8 +144,8 @@ def activate(base_config,

parsed_params = parse_from_args_and_config(params, cmd_args, ext_config,
enable_extern_config)
# tf2
if isinstance(base_config, object):
# TODO: consider a more elegant approach later
if "tensorflow2" in base_config.__path__:
base_config.override(parsed_params.__dict__, False)
else:
_merge_dict_to_config(parsed_params.__dict__, base_config.__dict__)
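The replaced condition isinstance(base_config, object) is true for every Python value, so the TensorFlow2-specific branch ran unconditionally; inspecting the config package's import path is one way to scope it. A small illustration (the stdlib email package stands in for any regular package):

```python
# Every value in Python is an instance of object, so the old check never failed.
print(isinstance(42, object))        # True
print(isinstance("config", object))  # True
print(isinstance(object, object))    # True

# For a regular package, __path__ lists the directories it was loaded from,
# which is the kind of information a framework-specific check can inspect.
import email  # stdlib package, used here purely for illustration
print(any("email" in p for p in email.__path__))  # True
```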
7 changes: 3 additions & 4 deletions training/benchmarks/driver/helper.py
@@ -4,11 +4,10 @@
import argparse
import os
import random
import time
import numpy as np
import torch
from driver import perf_logger, Driver, check
import driver
from driver import perf_logger, Driver, check


class InitHelper:
@@ -19,7 +18,6 @@ class InitHelper:
def __init__(self, config: object) -> None:
self.config = config
self.update_local_rank()
self.config = check.check_config(self.config)

def init_driver(self, global_module, local_module) -> Driver:
"""
@@ -29,7 +27,8 @@ def init_driver(self, global_module, local_module) -> Driver:
config = self.config
model_driver = Driver(config, config.mutable_params)
model_driver.setup_config(argparse.ArgumentParser(config.name))
model_driver.setup_modules(driver, global_module, local_module)
model_driver.setup_modules(global_module, local_module)
check.check_config(model_driver.config)
return model_driver

def get_logger(self) -> perf_logger.PerfLogger:
2 changes: 1 addition & 1 deletion training/benchmarks/mobilenetv2/pytorch/config/_base.py
@@ -12,7 +12,7 @@
# =========================================================
# data
# =========================================================
data_dir: str = "/home/data/imagenet"
data_dir: str = None
train_data: str = "train"
eval_data: str = "val"
output_dir: str = ""
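With the /home/data/imagenet default replaced by None, data_dir has to arrive from the launcher's --data_dir argument or from the vendor test config, as the commit title describes. A rough sketch of that kind of layering, with hypothetical paths and an assumed precedence (the project's actual merge order is defined by parse_from_args_and_config):

```python
# All values below are hypothetical; only the layering idea is illustrated.
base_defaults = {"data_dir": None, "train_data": "train", "eval_data": "val"}
test_config   = {"data_dir": "/mnt/data/imagenet"}   # vendor test config (assumed)
cli_args      = {"data_dir": "/raid/imagenet"}       # from --data_dir (assumed)

merged = dict(base_defaults)
merged.update({k: v for k, v in test_config.items() if v is not None})
merged.update({k: v for k, v in cli_args.items() if v is not None})
print(merged["data_dir"])  # /raid/imagenet in this sketch: the CLI value wins
```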
13 changes: 6 additions & 7 deletions training/benchmarks/mobilenetv2/pytorch/run_pretraining.py
@@ -7,9 +7,9 @@

CURR_PATH = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../")))
import config
from driver import Event, dist_pytorch
from driver.helper import InitHelper, get_finished_info

from driver.helper import InitHelper
from train import trainer_adapter
from train.evaluator import Evaluator
from train.trainer import Trainer
@@ -21,9 +21,8 @@


def main() -> Tuple[Any, Any]:
import config
global logger

global config
init_helper = InitHelper(config)
model_driver = init_helper.init_driver(globals(), locals())
config = model_driver.config
@@ -97,14 +96,14 @@ def main() -> Tuple[Any, Any]:

if __name__ == "__main__":
start = time.time()
config, state = main()
config_update, state = main()
if not dist_pytorch.is_main_process():
sys.exit(0)

global_batch_size = dist_pytorch.global_batch_size(config)
global_batch_size = dist_pytorch.global_batch_size(config_update)
e2e_time = time.time() - start
finished_info = {"e2e_time": e2e_time}
if config.do_train:
if config_update.do_train:
training_perf = (global_batch_size *
state.global_steps) / state.raw_train_time
finished_info = {
@@ -1,5 +1,3 @@
import os
from port_for import is_available
import torch
import torch.distributed as dist
from torch.optim import Optimizer
10 changes: 6 additions & 4 deletions training/run_benchmarks/pytorch/start_pytorch_task.py
@@ -127,18 +127,20 @@ def _get_basic_train_script_args(task_args):
'''Generate basic train script args according to the script options.'''
config_dir, config_file = helper.get_config_dir_file(task_args)
if config_dir is None or config_file is None:
START_LOGGER.error("Can't find config dir or config file.")
START_LOGGER.error(
f"Can't find config dir :{config_dir} or config file:{config_file}."
)
return None
if task_args.enable_extern_config:
extern_module_dir = helper.get_extern_module_dir(task_args)
if extern_module_dir is None:
START_LOGGER.error("Can't find extern module dir.")
return None

basic_train_script_args = " --data_dir " + task_args.data_dir \
+ " --extern_config_dir " + config_dir \
basic_train_script_args = " --extern_config_dir " + config_dir \
+ " --extern_config_file " + config_file \
+ " --vendor " + task_args.vendor
+ " --vendor " + task_args.vendor \
+ " --data_dir " + task_args.data_dir
if task_args.enable_extern_config:
basic_train_script_args += " --enable_extern_config " \
+ "--extern_module_dir " + extern_module_dir
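For reference, a minimal sketch of the argument string this function now assembles, using hypothetical paths and vendor name; only the flag order reflects the change above:

```python
# Hypothetical inputs; the real values come from task_args and helper functions.
config_dir = "config"
config_file = "config_A100x1x8.py"
vendor = "nvidia"
data_dir = "/raid/dataset"

basic_train_script_args = (" --extern_config_dir " + config_dir
                           + " --extern_config_file " + config_file
                           + " --vendor " + vendor
                           + " --data_dir " + data_dir)
print(basic_train_script_args.strip())
# --extern_config_dir config --extern_config_file config_A100x1x8.py --vendor nvidia --data_dir /raid/dataset
```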
