
[AutoParallel] add pipeline.auto_parallel_profiler to auto_config #7343

Merged: 22 commits into PaddlePaddle:develop on Dec 15, 2023

Conversation

@AndSonder (Contributor) commented Oct 31, 2023

PR types

New features

PR changes

Others

Description

PR PaddlePaddle/Paddle#58313 added the ability to visualize the pipeline-parallel schedule timeline under static-graph mode in Paddle. This PR adds pipeline.auto_parallel_profiler to auto_config so that users can enable the feature through a command-line argument.

GPT-3 test case

Example run:

log_dir=log_auto_6.7B_mp2pp4_st
rm -rf $log_dir
 
export FLAGS_embedding_deterministic=1
export FLAGS_cudnn_deterministic=1

 #nsys profile --stats=true -t cuda,nvtx -o 6.7B_st --capture-range=cudaProfilerApi --force-overwrite true \
python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3" \
         ./tools/auto.py \
         -c ./ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_1.3B_dp8.yaml \
         -o Global.local_batch_size=4 \
         -o Global.micro_batch_size=1 \
         -o Engine.max_steps=50 \
         -o Engine.logging_freq=10 \
         -o Engine.eval_freq=100 \
         -o Engine.save_load.save_steps=1000 \
         -o Model.use_recompute=True \
         -o Model.hidden_dropout_prob=0 \
         -o Model.attention_probs_dropout_prob=0 \
         -o Model.hidden_size=512 \
         -o Distributed.pipeline.schedule_mode=FThenB \
         -o Distributed.pipeline.auto_parallel_profiler=1 \
         -o Distributed.dp_degree=1 \
         -o Distributed.mp_degree=2 \
         -o Distributed.pp_degree=2 \
         -o Distributed.sharding.sharding_degree=1 \
         -o Distributed.sharding.sharding_stage=1 \
         -o Profiler_auto.memory_stats=True \
         -o Engine.verbose=3 \
         -o Model.hidden_dropout_prob=0 \
         -o Model.attention_probs_dropout_prob=0 \

Here, Distributed.pipeline.auto_parallel_profiler=1 is the flag that enables the feature.
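For readers unfamiliar with the -o syntax, each dotted key is folded into the nested auto_config structure. The sketch below is illustrative only (it is not the actual ppfleetx override code) and simply shows how such an override maps onto the nested config:

# Illustrative sketch of how a dotted "-o key.path=value" override folds into
# a nested config dict; not the actual ppfleetx implementation.
def apply_override(config, override):
    key_path, value = override.split("=", 1)
    keys = key_path.split(".")
    node = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
    return config

print(apply_override({}, "Distributed.pipeline.auto_parallel_profiler=1"))
# {'Distributed': {'pipeline': {'auto_parallel_profiler': '1'}}}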

Afterwards, run the parsing script to generate the JSON file:

python python/paddle/distributed/auto_parallel/static/profiler_helper_static.py --devices 0,1,2,3 --log_dir /home/root/PaddleNLP/model_zoo/gpt-3/log_auto_6.7B_mp2pp4_st/

[screenshot: profiler_helper_static.py output]

Opening pipeline_profile_perfetto.json in https://ui.perfetto.dev/ looks like this:

[screenshot: pipeline timeline rendered in the Perfetto UI]
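Before loading the file into the Perfetto UI, a quick sanity check like the sketch below confirms the trace was written; the file name comes from the step above, and the Chrome-trace layout (a top-level list or a dict with a "traceEvents" key) is an assumption rather than something taken from the helper script.

# Minimal sanity check of the generated trace; the Chrome-trace layout is an
# assumption, not taken from profiler_helper_static.py itself.
import json

with open("pipeline_profile_perfetto.json") as f:
    trace = json.load(f)

events = trace if isinstance(trace, list) else trace.get("traceEvents", [])
print(f"loaded {len(events)} trace events")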

Llama2 test case

task_name="llama_7b_pp2_mp4_st"
rm -rf output/$task_name/
rm -rf "output/$task_name""_log"

export SOT_LOG_LEVEL=4
export PYTHONPATH=../../:$PYTHONPATH

export FLAGS_embedding_deterministic=1
export FLAGS_cudnn_deterministic=1

export CUDA_DEVICE_MAX_CONNECTIONS=1

python -u  -m paddle.distributed.launch \
     --gpus "0,1,2,3" \
     --log_dir "output/$task_name""_log" \
     run_pretrain_auto.py \
     --model_type "llama" \
     --model_name_or_path "meta-llama/Llama-2-7b" \
     --tokenizer_name_or_path "meta-llama/Llama-2-7b" \
     --input_dir "./data" \
     --output_dir "output/$task_name" \
     --split 949,50,1 \
     --max_seq_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
     --gradient_accumulation_steps 4 \
     --use_flash_attention 0 \
     --use_fused_rms_norm 0 \
     --fp16 0 \
     --fp16_opt_level "O2"  \
     --scale_loss 1024 \
     --pipeline_parallel_degree 4 \
     --tensor_parallel_degree 1 \
     --sharding_parallel_degree 1 \
     --sharding "stage1" \
     --learning_rate 0.0001 \
     --min_learning_rate 0.00001 \
     --max_steps 10 \
     --save_steps 5000 \
     --weight_decay 0.01 \
     --warmup_ratio 0.01 \
     --max_grad_norm 1.0 \
     --logging_steps 1 \
     --dataloader_num_workers 1 \
     --sharding "" \
     --eval_steps 1000 \
     --report_to "visualdl" \
     --disable_tqdm true \
     --continue_training 0 \
     --recompute 1 \
     --do_train \
     --do_eval 0 \
     --device "gpu" \
     --data_impl "mmap" \
     --parallel_mode "auto" \
     --job_schedule_profiler_start 0 \
     --job_schedule_profiler_end 5 \

Visualization result:

python python/paddle/distributed/auto_parallel/static/profiler_helper_static.py --devices 0,1,2,3 --log_dir /home/workspace/PaddleNLP/llm/llama/output/llama_7b_pp2_mp4_st_log
[screenshot: pipeline timeline for the Llama2 run]

Depends on PR:

@AndSonder changed the title from "[feat] add pipeline.auto_parallel_profiler to auto_config" to "[feat][AutoParallel] add pipeline.auto_parallel_profiler to auto_config" on Oct 31, 2023
codecov bot commented Oct 31, 2023

Codecov Report

Attention: 40 lines in your changes are missing coverage. Please review.

Comparison is base (1bfe864) 57.59% compared to head (c4efefe) 57.58%.
Report is 1 commit behind head on develop.

Files                                           Patch %   Lines
paddlenlp/transformers/llama/modeling_auto.py     9.30%    39 Missing ⚠️
paddlenlp/trainer/training_args.py                0.00%     1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7343      +/-   ##
===========================================
- Coverage    57.59%   57.58%   -0.02%     
===========================================
  Files          582      582              
  Lines        86912    86929      +17     
===========================================
+ Hits         50061    50062       +1     
- Misses       36851    36867      +16     


Comment on lines 549 to 554
if (step == job_schedule_profiler_start) and training_args.use_auto_parallel:
engine.enable_job_schedule_profiler = True

if (step == job_schedule_profiler_end) and training_args.use_auto_parallel:
engine.enable_job_schedule_profiler = False
sys.exit()
Collaborator

Could this be written in a guard form, similar to nvprof_guard?

Contributor Author

I did consider implementing this in the nvprof_guard style, but nvprof_guard calls the C++ API directly because it only needs to push and pop. Here we need to start the profiler after the chosen step by changing the argument that is passed in, so the nvprof approach does not really fit.
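For context, a guard in the spirit of the reviewer's suggestion could look roughly like the sketch below. The helper name is hypothetical and this is only an illustration of the idea, not the code that was merged; it toggles engine.enable_job_schedule_profiler around the profiled step range and exits after the end step, mirroring the snippet quoted above.

# Hypothetical guard form of the quoted snippet; illustration only.
import sys
from contextlib import contextmanager

@contextmanager
def job_schedule_profiler_guard(step, start, end, engine):
    if step == start:
        engine.enable_job_schedule_profiler = True
    try:
        yield
    finally:
        if step == end:
            engine.enable_job_schedule_profiler = False
            sys.exit()

# Usage inside the training loop:
#     with job_schedule_profiler_guard(step, start, end, engine):
#         run_one_training_step(...)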

@@ -325,6 +325,10 @@ class TrainingArguments:
Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc.
distributed_dataloader (`bool`, *optional*):
Whether to use distributed dataloader. Default is `False`.
job_schedule_profiler_start (`int`, *optional*):
Collaborator

Could the command-line argument reuse pipeline_parallel_config?

Contributor Author

Re: "Could the command-line argument reuse pipeline_parallel_config?"

It seems not; pipeline_parallel_config can only hold boolean values.
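For reference, since pipeline_parallel_config only carries boolean switches, the two step bounds presumably end up as separate integer fields on TrainingArguments. Below is a minimal sketch assuming plain dataclass fields; the -1 "disabled" defaults and the help texts are assumptions, and the class is reduced to the two new fields for illustration.

# Hypothetical shape of the new options; defaults and help strings assumed.
from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    job_schedule_profiler_start: int = field(
        default=-1,
        metadata={"help": "Step at which the pipeline schedule profiler is enabled."},
    )
    job_schedule_profiler_end: int = field(
        default=-1,
        metadata={"help": "Step at which the pipeline schedule profiler is disabled."},
    )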

@AndSonder changed the title from "[feat][AutoParallel] add pipeline.auto_parallel_profiler to auto_config" to "[AutoParallel] add pipeline.auto_parallel_profiler to auto_config" on Dec 7, 2023
@@ -27,6 +27,7 @@
import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto
from paddle.utils.profiler import job_schedule_profiler_range
Collaborator

Should be: from paddle.profiler.utils import job_schedule_profiler_range

@From00 (Collaborator) left a comment

LGTM

@From00 merged commit 5106809 into PaddlePaddle:develop on Dec 15, 2023
9 checks passed