
[CUDAGraph] GPT3-175B Pipeline Parallel Training with CUDAGraph using PipelineParallelMicroStepCallback #65634


Merged: 5 commits merged into PaddlePaddle:develop on Jul 29, 2024

Conversation

eee4017
Contributor

@eee4017 eee4017 commented Jul 2, 2024

PR Category

Distributed Strategy

PR Types

New features

Description

This PR introduces enhancements and fixes that improve the functionality and debuggability of pipeline parallel training in PaddlePaddle. The primary addition is the PipelineParallelMicroStepCallback, which enables registering hooks at micro-step boundaries within pipeline parallel execution. This update is required to support CUDA Graph pipeline parallel training and includes several other improvements.

Key Features

  1. PipelineParallelMicroStepCallback:

    • This new feature facilitates enhanced management of hooks within pipeline parallel processes.
    • It allows for registering callbacks at specific pipeline stages: forward_begin, forward_end, backward_begin, and backward_end.
    • This functionality is particularly important for PipelineParallel, where layers are divided into multiple chunks.
    • The addition supports various tasks, such as logging, monitoring, and dynamic parameter adjustments during pipeline execution.
    • This feature aligns with PaddlePaddle's design principles and addresses the specific needs of pipeline parallelism.
  2. Support for CUDA Graph Pipeline Parallel Training:

    • The update is essential for enabling efficient pipeline parallel training with CUDA Graph.
    • It allows training of large models, such as GPT-3 175B, on 64 H100 GPUs using hybrid parallelism (Pipeline Parallel + Tensor Parallel + Sequence Parallel), achieving a 1.18x speedup in training performance.
  3. Worker Log Adjustment:

    • Updated worker log filenames to use the global trainer rank instead of the node-local device index. This ensures that worker logs from different nodes no longer collide under the same filename, which aids debugging and clarity.
  4. Debug Tools and Fixes:

    • Added several debug tools and fixed issues in the CUDA graphed layer, enhancing the overall debugging experience and reliability of CUDA Graph pipeline training.
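The hook stages described in item 1 can be sketched as a small registry. This is an illustrative sketch only, not the actual PaddlePaddle API; the class and method names here (`MicroStepCallbackRegistry`, `register`, `on`) are assumptions for demonstration.

```python
# Minimal sketch of a micro-step hook registry mirroring the four stages
# named above: forward_begin, forward_end, backward_begin, backward_end.
# Names and interface are illustrative, not the real PaddlePaddle API.
from collections import defaultdict

STAGES = ("forward_begin", "forward_end", "backward_begin", "backward_end")

class MicroStepCallbackRegistry:
    def __init__(self):
        self._hooks = defaultdict(list)

    def register(self, stage, fn):
        # Only the four documented micro-step stages are accepted.
        if stage not in STAGES:
            raise ValueError(f"unknown stage: {stage}")
        self._hooks[stage].append(fn)

    def on(self, stage, **kwargs):
        # Called by the pipeline runner at each micro-step boundary;
        # useful for logging, monitoring, or parameter adjustments.
        for fn in self._hooks[stage]:
            fn(**kwargs)

# Example: record forward events for chunk 0 across two micro-batches.
events = []
registry = MicroStepCallbackRegistry()
registry.register("forward_begin", lambda step, chunk: events.append(("fb", step, chunk)))
registry.register("forward_end", lambda step, chunk: events.append(("fe", step, chunk)))

for step in range(2):
    registry.on("forward_begin", step=step, chunk=0)
    registry.on("forward_end", step=step, chunk=0)
```

Because PipelineParallel splits layers into multiple chunks, passing the chunk index to each hook (as above) lets a single callback distinguish which stage of the model is executing.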

Minor Fixes

  • Adjusted the worker log to rank-based logging.
  • Improved debug tools and fixed issues in the CUDA graphed layer.

For more information, please check NVIDIA/TransformerEngine#957 and #65092.


paddle-bot bot commented Jul 2, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jul 2, 2024
@jeng1220 jeng1220 added the NVIDIA label Jul 2, 2024
@eee4017 eee4017 force-pushed the cudagraph_175b_github_submit branch from bd6366e to ed69d5f Compare July 4, 2024 04:39
@jeng1220
Collaborator

You must have one RD (phlrain or luotao1 or Aurelius84) approval


paddle-ci-bot bot commented Jul 12, 2024

Sorry to inform you that ed69d5f's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@eee4017 eee4017 force-pushed the cudagraph_175b_github_submit branch 2 times, most recently from 1fe672f to 434c7eb Compare July 15, 2024 05:34
@@ -111,7 +111,7 @@ def _build_pod_with_args(self):
             "POD_IP": self.ctx.node.ip,
         }
         e.update(_gloo_envs)
-        log_file = f"workerlog.{i}"
+        log_file = f"workerlog.{i + trainer_rank_offset}"
Contributor


Does this log naming format have to change?
Some cluster-side log monitoring and analysis programs depend on the log filename ending in [0–7]. If this needs to be updated, both sides must be aligned at the same time.

Contributor Author

@eee4017 eee4017 Jul 18, 2024


With multiple nodes, these logs pile on top of each other and become unreadable when debugging, so I'd like to change this numbering. Single-node behavior should be the same as before and is unaffected.

Contributor Author

@eee4017 eee4017 Jul 18, 2024


The CI-coverage check does not cover enough here; much of the uncovered code seems to be in this workerlog part, which apparently was never tested in the first place.

Collaborator


@tianshuo78520a,
is the CI-coverage shortfall mentioned above also something you would handle?

Collaborator


@tianshuo78520a, is the CI-coverage shortfall mentioned above also something you would handle?

Already handled.

Member


With multiple nodes, these logs pile on top of each other and become unreadable when debugging, so I'd like to change this numbering. Single-node behavior should be the same as before and is unaffected.

What does "multiple nodes" mean here? Does this change alter the log-saving behavior?

Contributor Author

@eee4017 eee4017 Jul 25, 2024


"Multiple nodes" means multiple machines. With multiple machines and multiple cards, device 0 on every machine gets the same log number, so the logs from different machines pile on top of each other.
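The collision described above can be sketched in a few lines. This is an illustrative sketch only; the assumption that `trainer_rank_offset` equals `node_index * gpus_per_node` is mine, for demonstration, and `worker_logs` is a hypothetical helper, not code from the PR.

```python
# Sketch of the filename change: with node-local indices every node emits
# workerlog.0 .. workerlog.{k-1}, so names collide across nodes; adding a
# per-node rank offset (assumed here to be node_index * gpus_per_node)
# makes the names globally unique.
def worker_logs(num_nodes, gpus_per_node, rank_based):
    names = []
    for node in range(num_nodes):
        offset = node * gpus_per_node if rank_based else 0
        for i in range(gpus_per_node):
            names.append(f"workerlog.{i + offset}")
    return names

old = worker_logs(2, 8, rank_based=False)  # 16 files, only 8 distinct names
new = worker_logs(2, 8, rank_based=True)   # 16 files, 16 distinct names
```

Note that with a single node the offset is zero, so single-node log names are unchanged, which matches the author's claim that single-node behavior is unaffected (though cluster-side tooling that expects suffixes 0–7 would still need alignment for multi-node runs).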


paddle-ci-bot bot commented Jul 23, 2024

Sorry to inform you that 434c7eb's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@eee4017 eee4017 force-pushed the cudagraph_175b_github_submit branch from 434c7eb to 45b9ed9 Compare July 24, 2024 15:56
@jeng1220
Collaborator

@sneaxiy , @JZ-LIANG , @ForFishes , @tianshuo78520a
All tests have passed. Can this PR be merged now?

Member

@ForFishes ForFishes left a comment


LGTM

@ForFishes ForFishes merged commit c430ee4 into PaddlePaddle:develop Jul 29, 2024
31 checks passed
lixcli pushed a commit to lixcli/Paddle that referenced this pull request Aug 5, 2024
… PipelineParallelMicroStepCallback (PaddlePaddle#65634)

* CUDAGraph: PP hook and workerlog.rank

* fix header

* change logging.info to print

* fix pp hook

* fix logger

---------

Co-authored-by: Frank Lin (Engrg-Hardware 1) <fralin@nvidia.com>
Labels
contributor External developers NVIDIA

5 participants