Commit 80542a1

Merge pull request pytorch#488 from mrshenli/rst
Adding DDP tutorial in rst format
2 parents 63d6601 + c342d2b commit 80542a1

File tree

index.rst
intermediate_source/ddp_tutorial.rst

2 files changed: +261 -1 lines changed

index.rst

+6 -1
@@ -7,7 +7,7 @@ starting point, and provides a broad view into how to use PyTorch from the basic
 
 Some considerations:
 
-* We’ve added a new feature to tutorials that allows users to open the notebook associated with a tutorial in Google Colab.
+* We’ve added a new feature to tutorials that allows users to open the notebook associated with a tutorial in Google Colab.
   Visit `this page <https://pytorch.org/tutorials/beginner/colab.html>`_ for more information.
 * If you would like to do the tutorials interactively via IPython / Jupyter,
   each tutorial has a download link for a Jupyter Notebook and Python source code.
@@ -216,6 +216,11 @@ Production Usage
    :description: :doc:`/intermediate/model_parallel_tutorial`
    :figure: _static/img/distributed/DistPyTorch.jpg
 
+.. customgalleryitem::
+   :tooltip: Getting started with DistributedDataParallel
+   :description: :doc:`/intermediate/ddp_tutorial`
+   :figure: _static/img/distributed/DistPyTorch.jpg
+
 .. customgalleryitem::
    :tooltip: PyTorch distributed trainer with Amazon AWS
    :description: :doc:`/beginner/aws_distributed_training_tutorial`

intermediate_source/ddp_tutorial.rst

+255
@@ -0,0 +1,255 @@
Getting Started with Distributed Data Parallel
==============================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

`DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
(DDP) implements data parallelism at the module level. It uses communication
collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
package to synchronize gradients, parameters, and buffers. Parallelism is
available both within a process and across processes. Within a process, DDP
replicates the input module to devices specified in ``device_ids``, scatters
inputs along the batch dimension accordingly, and gathers outputs to the
``output_device``, which is similar to
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
Across processes, DDP inserts necessary parameter synchronizations in forward
passes and gradient synchronizations in backward passes. It is up to users to
map processes to available resources, as long as processes do not share GPU
devices. The recommended (usually fastest) approach is to create a process for
every module replica, i.e., no module replication within a process. The code in
this tutorial runs on an 8-GPU server, but it can be easily generalized to
other environments.

Basic Use Case
--------------

To create DDP modules, first set up process groups properly. More details can
be found in
`WRITING DISTRIBUTED APPLICATIONS WITH PYTORCH <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__.

.. code:: python

    import os
    import tempfile
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.optim as optim
    import torch.multiprocessing as mp

    from torch.nn.parallel import DistributedDataParallel as DDP


    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly setting seed to make sure that models created in two processes
        # start from same random weights and biases.
        torch.manual_seed(42)


    def cleanup():
        dist.destroy_process_group()

Now, let's create a toy module, wrap it with DDP, and feed it some dummy
input data. Please note that if training starts from random parameters, you
might want to make sure that all DDP processes use the same initial values.
Otherwise, global gradient synchronization will not make sense.

.. code:: python

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def demo_basic(rank, world_size):
        setup(rank, world_size)

        # set up devices for this process; with world_size=2, rank 0 uses
        # GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
        n = torch.cuda.device_count() // world_size
        device_ids = list(range(rank * n, (rank + 1) * n))

        # create the model and move it to device_ids[0]
        model = ToyModel().to(device_ids[0])
        # output_device defaults to device_ids[0]
        ddp_model = DDP(model, device_ids=device_ids)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # DDP scatters the input along the batch dimension to the devices in
        # device_ids and gathers the output to device_ids[0]
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_ids[0])
        loss_fn(outputs, labels).backward()
        optimizer.step()

        cleanup()


    def run_demo(demo_fn, world_size):
        # spawn one process per rank; each process runs demo_fn(rank, world_size)
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)

As you can see, DDP wraps lower-level distributed communication details and
provides a clean API, as if it were a local model. For basic use cases, DDP
only requires a few more lines of code to set up the process group. When
applying DDP to more advanced use cases, there are some caveats that require
caution.

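As noted above, the recommended layout is one process per module replica. The
sketch below adapts ``demo_basic`` to that layout; it is illustrative only, it
assumes ``world_size`` equals the number of visible GPUs, and it reuses the
``setup``, ``cleanup``, and ``ToyModel`` definitions from this section.

.. code:: python

    def demo_one_device_per_process(rank, world_size):
        # hypothetical variant of demo_basic: each process drives exactly one
        # GPU, so there is no module replication within a process
        setup(rank, world_size)

        device = rank  # assumes world_size == torch.cuda.device_count()
        model = ToyModel().to(device)
        ddp_model = DDP(model, device_ids=[device], output_device=device)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        cleanup()

For example, ``run_demo(demo_one_device_per_process, torch.cuda.device_count())``
would launch one such process per GPU.
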
Skewed Processing Speeds
------------------------

In DDP, the constructor, the forward method, and the differentiation of the
outputs are distributed synchronization points. Different processes are
expected to reach these synchronization points in the same order and to enter
each one at roughly the same time. Otherwise, fast processes might arrive
early and time out while waiting for stragglers. Hence, users are responsible
for balancing workload distribution across processes. Sometimes, skewed
processing speeds are inevitable due to, e.g., network delays, resource
contention, or unpredictable workload spikes. To avoid timeouts in these
situations, make sure that you pass a sufficiently large ``timeout`` value
when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.

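As a minimal sketch, the ``setup`` helper above could pass an explicit
``timeout``; the 60-minute value here is an arbitrary assumption and should be
tuned to your workload.

.. code:: python

    from datetime import timedelta

    def setup_with_timeout(rank, world_size):
        # same rendezvous settings as in setup() above
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # give slow processes up to 60 minutes (an illustrative value) to reach
        # each synchronization point before collectives time out
        dist.init_process_group("gloo",
                                rank=rank,
                                world_size=world_size,
                                timeout=timedelta(minutes=60))
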
Save and Load Checkpoints
-------------------------

It's common to use ``torch.save`` and ``torch.load`` to checkpoint modules
during training and recover from checkpoints. See
`SAVING AND LOADING MODELS <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
for more details. When using DDP, one optimization is to save the model in
only one process and then load it in all processes, reducing write overhead.
This is correct because all processes start from the same parameters and
gradients are synchronized in backward passes, and hence optimizers keep
setting parameters to the same values. If you use this optimization, make sure
no process starts loading before the saving is finished. Besides, when loading
the module, you need to provide an appropriate ``map_location`` argument to
prevent a process from stepping into other processes' devices. If
``map_location`` is missing, ``torch.load`` will first load the module to CPU
and then copy each parameter to where it was saved, which would result in all
processes on the same machine using the same set of devices.

.. code:: python

    def demo_checkpoint(rank, world_size):
        setup(rank, world_size)

        # set up devices for this process; with world_size=2, rank 0 uses
        # GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
        n = torch.cuda.device_count() // world_size
        device_ids = list(range(rank * n, (rank + 1) * n))

        model = ToyModel().to(device_ids[0])
        # output_device defaults to device_ids[0]
        ddp_model = DDP(model, device_ids=device_ids)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
        if rank == 0:
            # All processes should see the same parameters as they all start
            # from the same random parameters and gradients are synchronized in
            # backward passes. Therefore, saving it in one process is sufficient.
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

        # Use a barrier() to make sure that process 1 loads the model after
        # process 0 saves it.
        dist.barrier()
        # configure map_location to map rank 0's devices to this rank's devices
        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
        device_pairs = zip(rank0_devices, device_ids)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_ids[0])
        loss_fn(outputs, labels).backward()
        optimizer.step()

        # Use a barrier() to make sure that all processes have finished reading
        # the checkpoint before rank 0 removes it.
        dist.barrier()

        if rank == 0:
            os.remove(CHECKPOINT_PATH)

        cleanup()

Combine DDP with Model Parallelism
----------------------------------

DDP also works with multi-GPU models, but replication within a process is not
supported. You need to create one process per module replica, which usually
leads to better performance compared to multiple replicas per process.
Wrapping multi-GPU models with DDP is especially helpful when training large
models with a huge amount of data. When using this feature, the multi-GPU
model needs to be carefully implemented to avoid hard-coded devices, because
different model replicas will be placed on different devices.

.. code:: python

    class ToyMpModel(nn.Module):
        def __init__(self, dev0, dev1):
            super(ToyMpModel, self).__init__()
            self.dev0 = dev0
            self.dev1 = dev1
            self.net1 = torch.nn.Linear(10, 10).to(dev0)
            self.relu = torch.nn.ReLU()
            self.net2 = torch.nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            # move activations between the two devices inside forward()
            x = x.to(self.dev0)
            x = self.relu(self.net1(x))
            x = x.to(self.dev1)
            return self.net2(x)

When passing a multi-GPU model to DDP, ``device_ids`` and ``output_device``
must NOT be set. Input and output data will be placed on the proper devices by
either the application or the model's ``forward()`` method.

.. code:: python

    def demo_model_parallel(rank, world_size):
        setup(rank, world_size)

        # set up mp_model and devices for this process
        dev0 = rank * 2
        dev1 = rank * 2 + 1
        mp_model = ToyMpModel(dev0, dev1)
        # do NOT set device_ids or output_device when wrapping a multi-GPU model
        ddp_mp_model = DDP(mp_model)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # outputs will be on dev1
        outputs = ddp_mp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(dev1)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        cleanup()


    if __name__ == "__main__":
        run_demo(demo_basic, 2)
        run_demo(demo_checkpoint, 2)

        if torch.cuda.device_count() >= 8:
            run_demo(demo_model_parallel, 4)
