Getting Started with Distributed Data Parallel
==============================================
**Author**: `Shen Li <https://mrshenli.github.io/>`_

`DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
(DDP) implements data parallelism at the module level. It uses communication
collectives in the `torch.distributed <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__
package to synchronize gradients, parameters, and buffers. Parallelism is
available both within a process and across processes. Within a process, DDP
replicates the input module to devices specified in ``device_ids``, scatters
inputs along the batch dimension accordingly, and gathers outputs to the
``output_device``, which is similar to
`DataParallel <https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html>`__.
Across processes, DDP inserts necessary parameter synchronizations in forward
passes and gradient synchronizations in backward passes. It is up to users to
map processes to available resources, as long as processes do not share GPU
devices. The recommended (usually fastest) approach is to create a process for
every module replica, i.e., no module replication within a process. The code in
this tutorial runs on an 8-GPU server, but it can be easily generalized to
other environments.


Basic Use Case
--------------

To create DDP modules, first set up process groups properly. More details can
be found in
`Writing Distributed Applications with PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__.

.. code:: python

    import os
    import tempfile
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.optim as optim
    import torch.multiprocessing as mp

    from torch.nn.parallel import DistributedDataParallel as DDP


    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Explicitly setting seed to make sure that models created in two processes
        # start from same random weights and biases.
        torch.manual_seed(42)


    def cleanup():
        dist.destroy_process_group()

Now, let's create a toy module, wrap it with DDP, and feed it some dummy
input data. Please note that if training starts from random parameters, you
should make sure that all DDP processes use the same initial values.
Otherwise, global gradient synchronization will not make sense.

.. code:: python

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def demo_basic(rank, world_size):
        setup(rank, world_size)

        # setup devices for this process. With world_size=2, rank 0 uses
        # GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
        n = torch.cuda.device_count() // world_size
        device_ids = list(range(rank * n, (rank + 1) * n))

        # create model and move it to device_ids[0]
        model = ToyModel().to(device_ids[0])
        # output_device defaults to device_ids[0]
        ddp_model = DDP(model, device_ids=device_ids)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_ids[0])
        loss_fn(outputs, labels).backward()
        optimizer.step()

        cleanup()


    def run_demo(demo_fn, world_size):
        mp.spawn(demo_fn,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)

As you can see, DDP wraps lower-level distributed communication details and
provides a clean API, as if it were a local model. For basic use cases, DDP
only requires a few more lines of code to set up the process group. When
applying DDP to more advanced use cases, there are some caveats that require
caution.

Skewed Processing Speeds
------------------------

In DDP, the constructor, the forward method, and the differentiation of the
outputs are distributed synchronization points. Different processes are
expected to reach these synchronization points in the same order and enter
each one at roughly the same time. Otherwise, fast processes might arrive
early and time out while waiting for stragglers. Hence, users are responsible
for balancing workload distribution across processes. Sometimes, skewed
processing speeds are inevitable due to, e.g., network delays, resource
contention, or unpredictable workload spikes. To avoid timeouts in these
situations, make sure that you pass a sufficiently large ``timeout`` value
when calling
`init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
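
For example, the ``setup`` helper above could forward an explicit timeout to
``init_process_group``. The sketch below is only illustrative: the helper name
and the 60-minute value are assumptions, not part of the tutorial's code.

.. code:: python

    import datetime

    def setup_with_timeout(rank, world_size):
        # Hypothetical variant of the setup() helper above that allows
        # stragglers up to 60 minutes (an illustrative value) before
        # collective operations time out.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        dist.init_process_group(
            "gloo",
            rank=rank,
            world_size=world_size,
            timeout=datetime.timedelta(minutes=60))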

Save and Load Checkpoints
-------------------------

It's common to use ``torch.save`` and ``torch.load`` to checkpoint modules
during training and recover from checkpoints. See
`Saving and Loading Models <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
for more details. When using DDP, one optimization is to save the model in
only one process and then load it in all processes, reducing write overhead.
This is correct because all processes start from the same parameters and
gradients are synchronized in backward passes, so the optimizers keep setting
the parameters to the same values. If you use this optimization, make sure no
process starts loading before the saving is finished. Besides, when loading
the module, you need to provide an appropriate ``map_location`` argument to
prevent a process from stepping into other processes' devices. If
``map_location`` is missing, ``torch.load`` will first load the module to CPU
and then copy each parameter to where it was saved, which would result in all
processes on the same machine using the same set of devices.
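
As a quick illustration of the remapping, here is a minimal sketch assuming
one GPU per process; ``rank`` and ``CHECKPOINT_PATH`` refer to the same names
used in ``demo_checkpoint`` below, which instead assigns several GPUs to each
process and builds the dictionary accordingly.

.. code:: python

    # Minimal sketch, assuming one GPU per process: tensors saved on rank 0's
    # cuda:0 are remapped onto this rank's own device when loading.
    map_location = {'cuda:0': 'cuda:%d' % rank}
    state_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)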

.. code:: python

    def demo_checkpoint(rank, world_size):
        setup(rank, world_size)

        # setup devices for this process. With world_size=2, rank 0 uses
        # GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
        n = torch.cuda.device_count() // world_size
        device_ids = list(range(rank * n, (rank + 1) * n))

        model = ToyModel().to(device_ids[0])
        # output_device defaults to device_ids[0]
        ddp_model = DDP(model, device_ids=device_ids)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
        if rank == 0:
            # All processes should see same parameters as they all start from same
            # random parameters and gradients are synchronized in backward passes.
            # Therefore, saving it in one process is sufficient.
            torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

        # Use a barrier() to make sure that process 1 loads the model after process
        # 0 saves it.
        dist.barrier()
        # configure map_location properly
        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
        device_pairs = zip(rank0_devices, device_ids)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        ddp_model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=map_location))

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(device_ids[0])
        loss_fn(outputs, labels).backward()
        optimizer.step()

        # Use a barrier() to make sure that all processes have finished reading
        # the checkpoint.
        dist.barrier()

        if rank == 0:
            os.remove(CHECKPOINT_PATH)

        cleanup()

Combine DDP with Model Parallelism
----------------------------------

DDP also works with multi-GPU models, but replication within a process is not
supported. You need to create one process per module replica, which usually
leads to better performance than multiple replicas per process. Wrapping a
multi-GPU model with DDP is especially helpful when training large models with
a huge amount of data. When using this feature, the multi-GPU model needs to
be carefully implemented to avoid hard-coded devices, because different model
replicas will be placed on different devices.

.. code:: python

    class ToyMpModel(nn.Module):
        def __init__(self, dev0, dev1):
            super(ToyMpModel, self).__init__()
            self.dev0 = dev0
            self.dev1 = dev1
            self.net1 = torch.nn.Linear(10, 10).to(dev0)
            self.relu = torch.nn.ReLU()
            self.net2 = torch.nn.Linear(10, 5).to(dev1)

        def forward(self, x):
            x = x.to(self.dev0)
            x = self.relu(self.net1(x))
            x = x.to(self.dev1)
            return self.net2(x)

When passing a multi-GPU model to DDP, ``device_ids`` and ``output_device``
must NOT be set. Input and output data will be placed on the proper devices by
either the application or the model ``forward()`` method.

.. code:: python

    def demo_model_parallel(rank, world_size):
        setup(rank, world_size)

        # setup mp_model and devices for this process
        dev0 = rank * 2
        dev1 = rank * 2 + 1
        mp_model = ToyMpModel(dev0, dev1)
        ddp_mp_model = DDP(mp_model)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # outputs will be on dev1
        outputs = ddp_mp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(dev1)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        cleanup()


    if __name__ == "__main__":
        run_demo(demo_basic, 2)
        run_demo(demo_checkpoint, 2)

        if torch.cuda.device_count() >= 8:
            run_demo(demo_model_parallel, 4)