Description
PyTorch 2.4 has a new API for pipeline parallelism, which includes `PipelineStage`. With this, we can subclass `PipelineStage` and override `forward_one_chunk` and `backward_one_chunk`, where each override first sets the GPU's frequency using the async frequency controller and then runs the actual forward/backward.
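
A minimal sketch of the override pattern, assuming stand-in classes so it runs without PyTorch or a GPU. `StandInStage` plays the role of `PipelineStage`, and `RecordingController`, `set_frequency`, and the MHz values are hypothetical names and numbers chosen for illustration; they are not the real controller's API. Arguments are passed through untouched so the sketch does not depend on the exact signatures of `forward_one_chunk` and `backward_one_chunk`.

```python
class RecordingController:
    """Hypothetical stand-in for the async frequency controller."""

    def __init__(self):
        self.requests = []

    def set_frequency(self, mhz):
        # A real controller would forward this request asynchronously;
        # here we only record it so the call order can be inspected.
        self.requests.append(mhz)


class StandInStage:
    """Stand-in for PipelineStage; only the two overridden methods matter."""

    def forward_one_chunk(self, *args, **kwargs):
        return "forward"

    def backward_one_chunk(self, *args, **kwargs):
        return "backward"


class FrequencyAwareStage(StandInStage):
    """Sets the GPU frequency, then delegates to the parent implementation."""

    def __init__(self, controller, forward_mhz, backward_mhz):
        super().__init__()
        self.controller = controller
        self.forward_mhz = forward_mhz
        self.backward_mhz = backward_mhz

    def forward_one_chunk(self, *args, **kwargs):
        self.controller.set_frequency(self.forward_mhz)
        return super().forward_one_chunk(*args, **kwargs)

    def backward_one_chunk(self, *args, **kwargs):
        self.controller.set_frequency(self.backward_mhz)
        return super().backward_one_chunk(*args, **kwargs)


controller = RecordingController()
stage = FrequencyAwareStage(controller, forward_mhz=1410, backward_mhz=1200)
forward_result = stage.forward_one_chunk(0)
backward_result = stage.backward_one_chunk(0)
```

Because the subclass only wraps the two methods and calls `super()`, a pipeline schedule that drives the stage should not need to know the frequency logic exists.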
In case users already have an instance of `PipelineStage` (manual splitting) or `_PipelineStage` (automatic splitting with `pipeline`), we can provide a static method on our `PipelineStage` subclass that converts the user's pipeline stage into ours.
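
One possible shape for that static method, sketched with stand-in classes so it runs without PyTorch. It assumes the conversion can be done by creating an instance of our subclass without calling `__init__` and sharing the existing stage's attributes; whether that is sufficient for a real `PipelineStage` or `_PipelineStage` is untested here. `ExistingStage`, `from_stage`, `RecordingController`, `set_frequency`, and the MHz value are all hypothetical names introduced for illustration.

```python
class RecordingController:
    """Hypothetical stand-in for the async frequency controller."""

    def __init__(self):
        self.requests = []

    def set_frequency(self, mhz):
        self.requests.append(mhz)


class ExistingStage:
    """Stand-in for a stage the user has already constructed."""

    def __init__(self, stage_index, num_stages):
        self.stage_index = stage_index
        self.num_stages = num_stages

    def forward_one_chunk(self, *args, **kwargs):
        return f"forward on stage {self.stage_index}"


class FrequencyAwareStage(ExistingStage):
    """Our subclass, built from an existing stage instead of from scratch."""

    @staticmethod
    def from_stage(stage, controller, forward_mhz=1410):
        # Skip __init__ so the user's stage is not rebuilt, then share
        # its attributes with the new object (shallow copy of __dict__).
        new_stage = FrequencyAwareStage.__new__(FrequencyAwareStage)
        new_stage.__dict__.update(stage.__dict__)
        new_stage.controller = controller
        new_stage.forward_mhz = forward_mhz
        return new_stage

    def forward_one_chunk(self, *args, **kwargs):
        self.controller.set_frequency(self.forward_mhz)
        return super().forward_one_chunk(*args, **kwargs)


user_stage = ExistingStage(stage_index=1, num_stages=4)
controller = RecordingController()
our_stage = FrequencyAwareStage.from_stage(user_stage, controller)
result = our_stage.forward_one_chunk(0)
```

Note that this sketch inherits from a single base class. Since `PipelineStage` and `_PipelineStage` are different classes, a real implementation would also have to decide how to preserve any behavior specific to the class of the stage being converted.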
A POC can be done on TorchTitan's `train.py` without having to modify TorchTitan.