[AutoParallel] add pipeline.auto_parallel_profiler to auto_config #7343

Merged: 22 commits, Dec 15, 2023
Changes from 9 commits
19 changes: 16 additions & 3 deletions llm/llama/run_pretrain_auto.py
@@ -500,9 +500,11 @@ def fn(layer):
def loss_func(loss, outputs):
    return loss

- total_train_batch_size = training_args.per_device_train_batch_size \
-     * training_args.gradient_accumulation_steps \
-     * training_args.data_parallel_degree
+ total_train_batch_size = (
+     training_args.per_device_train_batch_size
+     * training_args.gradient_accumulation_steps
+     * training_args.data_parallel_degree
+ )
print_config(training_args)

engine = auto.Engine(model, loss_func, optimizer, strategy=training_args.strategy)
@@ -538,8 +540,19 @@ def loss_func(loss, outputs):
global_step_last_logged = 0
start_time_last_logged = time.time()
tr_loss = float(0)

job_schedule_profiler_start = training_args.job_schedule_profiler_start
job_schedule_profiler_end = training_args.job_schedule_profiler_end

for epoch_idx in range(num_train_epochs):
    for step, inputs in enumerate(train_dataloader):
        if (step == job_schedule_profiler_start) and training_args.use_auto_parallel:
            engine.enable_job_schedule_profiler = True

        if (step == job_schedule_profiler_end) and training_args.use_auto_parallel:
            engine.enable_job_schedule_profiler = False
            sys.exit()
Collaborator:

Could this be written in a guard form, similar to nvprof_guard?

Contributor Author:

I did consider implementing it the nvprof_guard way, but nvprof_guard's implementation calls the C++ API directly, because all it needs to do is push and pop. Here we need to start the profiler after the enabling step by changing the arguments passed in, so the nvprof approach doesn't really fit.


        outs = engine.run(inputs, mode="train")

        if "loss" in outs:
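As a reading aid for the thread above, here is a hypothetical sketch of the guard form the reviewer suggests. It is not part of the PR: job_schedule_profiler_guard is an invented name, and the only thing it assumes is the enable_job_schedule_profiler flag added in this diff.

from contextlib import contextmanager

@contextmanager
def job_schedule_profiler_guard(engine, enabled=True):
    # Hypothetical guard (not in the PR): scope the Python-side flag that
    # engine.run() reads on its next call, restoring it on exit. Unlike
    # nvprof_guard, there is no C++ push/pop to invoke here; the profiler
    # only starts later, when the flag changes what run() passes to the
    # executor -- which is the contributor's reason for toggling per step.
    previous = getattr(engine, "enable_job_schedule_profiler", False)
    engine.enable_job_schedule_profiler = enabled
    try:
        yield
    finally:
        engine.enable_job_schedule_profiler = previous

Wrapping engine.run(inputs, mode="train") in such a guard would scope the flag, but it still could not start or stop the profiler immediately the way a push/pop API can, which matches the objection above.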
16 changes: 15 additions & 1 deletion model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import sys
import numpy as np

import paddle
@@ -80,7 +81,10 @@ def __init__(self, configs, module=None, mode="train"):

# Distributed
self._pp_degree = configs["Distributed"]["pp_degree"]

pipeline_cfg = configs.Distributed.get("pipeline", {})
self._job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
self._job_schedule_profiler_end = pipeline_cfg.get("job_schedule_profiler_end", -1)

# engine configs
self._configs = configs["Engine"]

@@ -140,6 +144,9 @@ def __init__(self, configs, module=None, mode="train"):
self.memory_stats = configs.get("Profiler_auto", {}).get("memory_stats", False)
self.nvprof_start = configs.get("Profiler_auto", {}).get("nvprof_start", -1)
self.nvprof_end = configs.get("Profiler_auto", {}).get("nvprof_end", -1)

if (self._job_schedule_profiler_start != -1) and use_new_executor():
    logger.info("Schedule Profiler start at step {} and end at step {}".format(self._job_schedule_profiler_start, self._job_schedule_profiler_end))

def _validate_batch(self, batch):
    if self._pp_degree > 1 or self._accumulate_steps == 1:
@@ -174,6 +181,13 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loade
self._auto_engine.prepare(mode="train")

for step, batch in enumerate(train_data_loader):
    if (step == self._job_schedule_profiler_start) and use_new_executor():
        self._auto_engine.enable_job_schedule_profiler = True

    if (step == self._job_schedule_profiler_end - 1) and use_new_executor():
        self._auto_engine.enable_job_schedule_profiler = False
        sys.exit()

    if epoch_index == self._load_recovery["epoch"]:
        if step < self._load_recovery["step"]:
            continue
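For reference, the nested config that auto_engine.py reads above looks roughly like the sketch below. The values are illustrative, and the real configs object supports attribute access (note the configs.Distributed.get(...) call in the diff), so the plain-dict access here is a simplification.

# Illustrative shape of the config consumed by auto_engine.py; values are made up.
configs = {
    "Distributed": {
        "pp_degree": 2,
        "pipeline": {
            "schedule_mode": "1F1B",
            "job_schedule_profiler_start": 5,   # first step with the profiler enabled; -1 disables it
            "job_schedule_profiler_end": 10,    # controls when the profiler is turned off
        },
    },
}

pipeline_cfg = configs["Distributed"].get("pipeline", {})
assert pipeline_cfg.get("job_schedule_profiler_start", -1) == 5
assert pipeline_cfg.get("job_schedule_profiler_end", -1) == 10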
2 changes: 2 additions & 0 deletions model_zoo/gpt-3/ppfleetx/utils/auto_config.py
@@ -240,6 +240,8 @@ def process_strategy(config):
    pipeline.schedule_mode = pipeline_cfg.get("schedule_mode", "1F1B")
    pipeline.micro_batch_size = config.Global.micro_batch_size
    pipeline.accumulate_steps = accumulate_steps
    pipeline.job_schedule_profiler_start = pipeline_cfg.get("job_schedule_profiler_start", -1)
    pipeline.job_schedule_profiler_stop = pipeline_cfg.get("job_schedule_profiler_stop", -1)

elif accumulate_steps > 1:
    # gradient merge config
16 changes: 16 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -325,6 +325,10 @@
    Whether to skip the profile timer; the timer records the time usage of forward/backward/step, etc.
distributed_dataloader (`bool`, *optional*):
    Whether to use distributed dataloader. Default is `False`.
job_schedule_profiler_start (`int`, *optional*):
Collaborator:

Could the command-line arguments reuse pipeline_parallel_config?

Contributor Author:

> Could the command-line arguments reuse pipeline_parallel_config?

Apparently not; pipeline_parallel_config can only hold boolean values.

    The start step of job schedule profiler. Default is `-1`.
job_schedule_profiler_end (`int`, *optional*):
    The end step of job schedule profiler. Default is `-1`.
"""

output_dir: str = field(
@@ -706,6 +710,16 @@
metadata={"help": "Whether to unify hybrid parallel checkpoint."},
)

job_schedule_profiler_start: Optional[int] = field(
default=-1,
metadata={"help": "The start step of job schedule profiler."},
)

job_schedule_profiler_end: Optional[int] = field(
default=-1,
metadata={"help": "The end step of job schedule profiler."},
)

def __post_init__(self):
    env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
    if env_local_rank != -1 and env_local_rank != self.local_rank and paddle.distributed.get_world_size() > 1:
@@ -1085,6 +1099,8 @@
pipeline.accumulate_steps = self.gradient_accumulation_steps
pipeline.micro_batch_size = self.per_device_train_batch_size
pipeline.schedule_mode = "1F1B"
pipeline.job_schedule_profiler_start = self.job_schedule_profiler_start
pipeline.job_schedule_profiler_end = self.job_schedule_profiler_end

Codecov / codecov/patch: Added line paddlenlp/trainer/training_args.py#L1103 was not covered by tests.

if self.amp_master_grad:
    warnings.warn("`amp_master_grad` is not supported NOW in AutoParallel!")
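Putting the pieces together, a hedged usage sketch for the two new trainer arguments. TrainingArguments is PaddleNLP's trainer-arguments class; the output directory and step values are illustrative, and an auto-parallel run is assumed, since run_pretrain_auto.py gates the toggles on training_args.use_auto_parallel.

from paddlenlp.trainer import TrainingArguments

# Illustrative: enable the job schedule profiler at step 5 and stop it
# (exiting the run) around step 10; -1, the default, disables both.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    job_schedule_profiler_start=5,
    job_schedule_profiler_end=10,
)

Passed on the command line, the same window would be --job_schedule_profiler_start 5 --job_schedule_profiler_end 10.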