
Update DDP docs for Dynamo/DDPOptimizer (pytorch#89096)
Pull Request resolved: pytorch#89096
Approved by: https://github.com/msaroufim
wconstab authored and pytorchmergebot committed Nov 30, 2022
1 parent 12f98f8 commit 4472837
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions docs/source/notes/ddp.rst
@@ -65,7 +65,18 @@ updated, and all models on different processes should be exactly the same.
os.environ["MASTER_PORT"] = "29500"
main()
DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper
before compiling the model, so that TorchDynamo can apply ``DDPOptimizer``
(graph-break optimizations) based on DDP bucket sizes. (See `TorchDynamo DDPOptimizer <./ddp.html#torchdynamo-ddpoptimizer>`_ for more information.)

TorchDynamo support for DDP currently requires setting ``static_graph=False``, due to
interactions between the graph tracing process and DDP's mechanism for observing operations happening on its module,
but this is expected to be fixed eventually.

.. code::

    ddp_model = DDP(model, device_ids=[rank])
    ddp_model = torch.compile(ddp_model)
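
A slightly fuller sketch of the same ordering is shown below. It assumes a single-process CPU/``gloo`` setup for brevity; the model, function name, and tensor shapes are illustrative only and not part of the DDP API.

.. code::

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def demo_compiled_ddp(rank, world_size):
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        # gloo keeps this runnable on CPU; use nccl and device_ids=[rank] on GPUs.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))

        # Wrap with DDP *before* compiling, so TorchDynamo sees the DDP wrapper
        # and DDPOptimizer can split graphs along DDP's bucket boundaries.
        ddp_model = DDP(model)
        compiled_model = torch.compile(ddp_model)

        out = compiled_model(torch.randn(20, 10))
        out.sum().backward()

        dist.destroy_process_group()

    if __name__ == "__main__":
        demo_compiled_ddp(rank=0, world_size=1)
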
Internal Design
^^^^^^^^^^^^^^^
@@ -193,3 +204,24 @@ DistributedDataParallel
.. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png
:alt: ddp_code.png
:width: 400 px


TorchDynamo DDPOptimizer
------------------------

DDP's performance advantage comes from overlapping allreduce collectives with computations during backwards.
AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph,
because allreduce ops are launched by autograd hooks *after* the whole optimized backwards computation finishes.

TorchDynamo's DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP's allreduce buckets
during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to
break the forward graph and then call AotAutograd and compilation on each section. This allows DDP's allreduce hooks
to fire in between sections of backwards, and schedule communications to overlap with compute.
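
As a rough illustration of the splitting idea (this is not the real DDPOptimizer implementation, which lives in the file linked below), consecutive parameters, in the reverse order that their gradients become ready during backwards, are grouped into buckets of at most the bucket-cap size, and each bucket boundary becomes a graph break:

.. code::

    # Simplified illustration only -- not the actual DDPOptimizer code.
    def split_by_bucket(param_sizes_bytes, bucket_cap_bytes=25 * 1024 * 1024):
        """Group consecutive parameter sizes (in reverse order, roughly the order
        DDP's backward hooks fire) into buckets of at most bucket_cap_bytes."""
        buckets, current, current_bytes = [], [], 0
        for size in reversed(param_sizes_bytes):
            if current and current_bytes + size > bucket_cap_bytes:
                buckets.append(current)
                current, current_bytes = [], 0
            current.append(size)
            current_bytes += size
        if current:
            buckets.append(current)
        return buckets

    # Four 10 MiB parameters with a 25 MiB cap -> two buckets, hence two
    # sub-graphs compiled separately, letting the first bucket's allreduce
    # overlap with the second sub-graph's backward compute.
    print(split_by_bucket([10 * 2**20] * 4))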

See `this blog post <https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860/1>`_ for
a more in-depth explanation and experimental results, or read the docs and code at
`torch/_dynamo/optimizations/distributed.py <https://github.com/pytorch/pytorch/blob/4908a12542798a3e8641faae6b74f068fdfc6778/torch/_dynamo/optimizations/distributed.py#L56>`_.

To debug DDPOptimizer, set ``torch._dynamo.config.log_level`` to DEBUG (for full graph dumps) or INFO
(for basic info about bucket boundaries). To disable DDPOptimizer, set ``torch._dynamo.config.optimize_ddp=False``.
DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation.
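
For example, a minimal sketch of those settings (assuming, as in this version of the codebase, that ``log_level`` accepts standard ``logging`` levels):

.. code::

    import logging

    import torch._dynamo

    # Full graph dumps from DDPOptimizer (DEBUG) or just bucket-boundary info (INFO).
    torch._dynamo.config.log_level = logging.DEBUG
    # torch._dynamo.config.log_level = logging.INFO

    # Disable DDPOptimizer entirely; DDP + TorchDynamo still work, but the
    # allreduce/compute overlap during backwards is lost.
    torch._dynamo.config.optimize_ddp = False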
