This repository was archived by the owner on Sep 19, 2022. It is now read-only.

Change Distributed Data Parallel example #124

Merged 2 commits on Jan 11, 2019
2 changes: 1 addition & 1 deletion examples/ddp/mnist/cpu/Dockerfile
@@ -39,4 +39,4 @@ WORKDIR /workspace
RUN chmod -R a+w /workspace

ADD . /opt/pytorch_dist_mnist
ENTRYPOINT ["mpirun", "-n", "4", "--allow-run-as-root", "python", "-u", "/opt/pytorch_dist_mnist/mnist_ddp_cpu.py"]
ENTRYPOINT ["mpirun", "-n", "2", "--allow-run-as-root", "python", "-u", "/opt/pytorch_dist_mnist/mnist_ddp_cpu.py"]
1 change: 0 additions & 1 deletion examples/ddp/mnist/cpu/README.md
@@ -7,4 +7,3 @@
```

**Note.** Each copy will utilise 1 CPU. You can bind each process to a CPU using `-cpu-slot`. For more details, see the [mpirun documentation](https://www.open-mpi.org/doc/v3.0/man1/mpirun.1.php).
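The number of copies started by `mpirun -n` becomes the distributed world size that the script sees. A minimal sketch of how each launched process would read its rank and world size, assuming the CPU example initialises `torch.distributed` with the MPI backend as the mpirun ENTRYPOINT suggests (the initialisation code itself is outside this diff):

```
# Illustrative only: rank/world size as seen by a process started with `mpirun -n 2`.
# Assumes a PyTorch build with MPI support and the MPI backend, per the Dockerfile's ENTRYPOINT.
import torch.distributed as dist

dist.init_process_group(backend='mpi')   # rank and size are supplied by the MPI launcher
rank = dist.get_rank()                   # 0 or 1 when launched with `mpirun -n 2`
size = dist.get_world_size()             # equals the value passed to `mpirun -n`
print("rank {} of {}".format(rank, size))
```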

56 changes: 2 additions & 54 deletions examples/ddp/mnist/cpu/mnist_ddp_cpu.py
@@ -10,6 +10,7 @@
import torch.optim as optim
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.nn.modules import Module
import torch.nn.parallel as ddp

from math import ceil
from random import Random
@@ -20,50 +21,6 @@

gbatch_size = 128

class DistributedDataParallel(Module):
    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.module = module
        self.first_call = True

        def allreduce_params():
            if (self.needs_reduction):
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def weight_broadcast(self):
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):
        if self.first_call:
            print("first broadcast start")
            self.weight_broadcast()
            self.first_call = False
            print("first broadcast done")
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

class Partition(object):
    """ Dataset-like object, but only access a subset of it. """

@@ -138,20 +95,12 @@ def partition_dataset(rank):
        dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler)
    return train_set, bsz

def average_gradients(model):
    """ Gradient averaging. """
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)
        param.grad.data /= size


def run(rank, size):
    """ Distributed Synchronous SGD Example """
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset(rank)
    model = Net()
    model = DistributedDataParallel(model)
    model = ddp.DistributedDataParallelCPU(model)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    print("num_batches = ", num_batches)
@@ -165,7 +114,6 @@ def run(rank, size):
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format(
            epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size()))
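This file now drops the hand-rolled `DistributedDataParallel` wrapper and the manual `average_gradients()` call in favour of `torch.nn.parallel.DistributedDataParallelCPU`, which broadcasts the parameters from rank 0 when the model is wrapped and averages gradients across ranks during `backward()`. A minimal sketch of the resulting training step (illustrative names, not the exact file contents; it assumes the process group has already been initialised for the launcher in use):

```
# Sketch of the simplified CPU training step after this change (illustrative).
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as ddp
import torch.optim as optim

dist.init_process_group(backend='mpi')          # matches the mpirun-based launch in this example

model = nn.Linear(784, 10)                      # stand-in for the example's Net()
model = ddp.DistributedDataParallelCPU(model)   # broadcasts weights, hooks gradient all-reduce
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

data = torch.randn(32, 784)                     # dummy batch for illustration
target = torch.randint(0, 10, (32,))

optimizer.zero_grad()
loss = F.nll_loss(F.log_softmax(model(data), dim=1), target)
loss.backward()                                 # gradients are averaged across ranks here
optimizer.step()                                # no manual average_gradients() needed
```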
54 changes: 2 additions & 52 deletions examples/ddp/mnist/gpu/Dockerfile
@@ -1,54 +1,4 @@
FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
ARG PYTHON_VERSION=3.6
FROM pytorch/pytorch:0.4_cuda9_cudnn7

RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
git \
curl \
vim \
wget \
ca-certificates \
openssh-client \
libjpeg-dev \
libpng-dev &&\
rm -rf /var/lib/apt/lists/*

RUN wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz && \
gunzip -c openmpi-3.0.0.tar.gz | tar xf - && \
cd openmpi-3.0.0 && \
./configure --prefix=/home/.openmpi --with-cuda && \
make all install

ENV PATH="$PATH:/home/.openmpi/bin"
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/.openmpi/lib/"

RUN ompi_info --parsable --all | grep mpi_built_with_cuda_support:value
RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda update conda && \
/opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \
/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH
# This must be done before pip so that requirements.txt is available
WORKDIR /opt/pytorch

RUN git clone --recursive https://github.com/pytorch/pytorch

RUN TORCH_CUDA_ARCH_LIST="3.5 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
cd pytorch/ && \
pip install -v .

RUN /opt/conda/bin/conda config --set ssl_verify False
RUN pip install --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org
RUN pip install --trusted-host pypi.org --trusted-host files.pythonhosted.org torchvision
WORKDIR /workspace
RUN chmod -R a+w /workspace
ADD . /opt/pytorch_dist_mnist

ENTRYPOINT ["mpirun", "-n", "4", "--allow-run-as-root", "python", "-u", "/opt/pytorch_dist_mnist/mnist_ddp_gpu.py"]


ENTRYPOINT ["python", "-u", "-m", "torch.distributed.launch", "--nproc_per_node", "2", "/opt/pytorch_dist_mnist/mnist_ddp_gpu.py"]
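The new ENTRYPOINT drops the MPI-built image in favour of the stock `pytorch/pytorch:0.4_cuda9_cudnn7` image and starts the workers with `torch.distributed.launch`. The launcher spawns `--nproc_per_node` copies of the script, exports the rendezvous settings (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`), and passes a `--local_rank` argument to each copy. A sketch of how a script typically consumes that argument (this part is not shown in the diff):

```
# Illustrative: picking up the per-process GPU index under torch.distributed.launch.
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)   # injected by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)       # pin this process to one GPU
dist.init_process_group(backend='nccl')      # rank/world size come from the exported env vars
print("rank {} of {} on cuda:{}".format(
    dist.get_rank(), dist.get_world_size(), args.local_rank))
```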
7 changes: 2 additions & 5 deletions examples/ddp/mnist/gpu/README.md
@@ -8,11 +8,8 @@
```
**Keep in mind!** The number of GPUs used by the workers and the master should be less than or equal to the number of available GPUs on your cluster/system. If you have fewer, we recommend reducing the number of workers, or using the master only (in case you have 1 GPU).

2. MPI can spawn a different number of copies. This is controlled by `mpirun -n` inside the Dockerfile.
2. If you have only 1 GPU, you have to control the number of processes per node. The ```nproc_per_node``` parameter inside the Dockerfile should be equal to the distributed world size.

```
mpirun -n <number_of_copies>
torch.distributed.launch --nproc_per_node <number_of_processes_per_node>
```

**Note.** Each copy will utilise 1 GPU.
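As an illustrative sanity check (not part of the example), a script can verify that the requested world size does not exceed the GPUs visible on the node:

```
# Illustrative guard: the distributed world size should not exceed the visible GPU count.
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
if dist.get_world_size() > torch.cuda.device_count():
    raise RuntimeError("world size {} exceeds the {} visible GPUs".format(
        dist.get_world_size(), torch.cuda.device_count()))
```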

59 changes: 4 additions & 55 deletions examples/ddp/mnist/gpu/mnist_ddp_gpu.py
@@ -11,6 +11,7 @@
import torch.optim as optim
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.nn.modules import Module
import torch.nn.parallel as ddp

from math import ceil
from random import Random
@@ -21,50 +22,6 @@

gbatch_size = 128

class DistributedDataParallel(Module):
    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.module = module
        self.first_call = True

        def allreduce_params():
            if (self.needs_reduction):
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def weight_broadcast(self):
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

    def forward(self, *inputs, **kwargs):
        if self.first_call:
            print("first broadcast start")
            self.weight_broadcast()
            self.first_call = False
            print("first broadcast done")
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

class Partition(object):
    """ Dataset-like object, but only access a subset of it. """

@@ -139,21 +96,13 @@ def partition_dataset(rank):
        dataset, batch_size=bsz, shuffle=(train_sampler is None), sampler=train_sampler)
    return train_set, bsz

def average_gradients(model):
    """ Gradient averaging. """
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0)
        param.grad.data /= size


def run(rank, size):
    """ Distributed Synchronous SGD Example """
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset(rank)
    model = Net()
    model = model.cuda()
    model = DistributedDataParallel(model)
    model = ddp.DistributedDataParallel(model)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
@@ -168,7 +117,6 @@ def run(rank, size):
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Epoch {} Loss {:.6f} Global batch size {} on {} ranks'.format(
            epoch, epoch_loss / num_batches, gbatch_size, dist.get_world_size()))
@@ -204,8 +152,9 @@ def write(self, x):
print("CUDA Device Name:", torch.cuda.get_device_name(0))
print("CUDA Version:", torch.version.cuda)
print("=========================\n")
dist.init_process_group(backend='mpi')
dist.init_process_group(backend='nccl')
size = dist.get_world_size()
print("world size = ",size)
rank = dist.get_rank()
init_print(rank, size)
run(rank, size)
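As on the CPU side, the custom wrapper and `average_gradients()` are gone; the model is moved to CUDA and wrapped in `torch.nn.parallel.DistributedDataParallel` on top of the NCCL backend, which synchronises the weights at construction and all-reduces gradients during `backward()`. The diff wraps the model without `device_ids`; the sketch below passes `device_ids=[local_rank]`, a common choice when each process drives a single GPU, and that part is an assumption rather than something this PR does:

```
# Sketch of the simplified GPU training step after this change (illustrative).
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as ddp
import torch.optim as optim

dist.init_process_group(backend='nccl')
local_rank = 0                                   # assumed: parsed from --local_rank in practice
torch.cuda.set_device(local_rank)

model = nn.Linear(784, 10).cuda()                # stand-in for the example's Net()
model = ddp.DistributedDataParallel(model, device_ids=[local_rank])
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

data = torch.randn(32, 784).cuda()
target = torch.randint(0, 10, (32,)).cuda()

optimizer.zero_grad()
loss = F.nll_loss(F.log_softmax(model(data), dim=1), target)
loss.backward()                                  # DDP averages gradients across ranks
optimizer.step()
```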