PyTorch goes distributed #241

Closed
@apaszke

Description

Together with @0mp, @VirrageS and @jytug, we're developing a torch.distributed package for PyTorch. All the work is being done in a fork on the thd branch (we didn't want to make a lot of unnecessary noise in the main repo). We're opening this issue so we can gather feedback on the API design from all of you.

The package will have two modes; the user has to pick one of them as part of initialisation.

Process group mode

This mode is very similar to the API defined by MPI. All processes are equal and are assigned ranks, and later on they can use a well-known set of communication collectives such as reduce, broadcast, allReduce, gather, scatter, etc.

Example:

import torch.distributed

# Every process runs the same script and joins the group at startup.
torch.distributed.init_process_group(backend='tcp')
my_rank = torch.distributed.get_rank()
num_processes = torch.distributed.get_num_processes()

...

# Point-to-point communication: rank 0 sends a tensor to rank 1,
# the remaining ranks receive from rank 0.
if my_rank == 0:
    torch.distributed.send(tensor, 1)
else:
    tensor = torch.distributed.recv(0)

...

# Collective communication: all processes take part in the reduction.
result = torch.distributed.all_reduce(tensor)
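
The remaining collectives would follow the same pattern. As a rough sketch, continuing the example above (the signatures below are written in the same style as the example and are not final):

# Sketch only: argument conventions are not final and may change.
# broadcast: rank 0's tensor is sent to every process.
params = torch.randn(20, 20) if my_rank == 0 else torch.zeros(20, 20)
params = torch.distributed.broadcast(params, 0)

# gather: rank 0 collects one tensor from each process.
partial = torch.randn(20, 20)
results = torch.distributed.gather(partial, 0)
# on rank 0, `results` would hold the tensors gathered from all processes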

Master-worker mode

This mode provides an API very similar to the torch.cuda package. At the beginning of your script, you would have to call torch.distributed.init_master_worker(backend='mpi').

Operation execution is asynchronous with respect to the master process; eventually we'll implement a CUDA-like concurrency model (streams + events). Until then, the only synchronization points are copies between the master and the workers.

Example:

import torch.distributed

# The script runs on the master; operations on distributed tensors are
# dispatched to worker nodes.
torch.distributed.init_master_worker(backend='tcp')

x = torch.distributed.FloatTensor(20, 20).fill_(4)
# dist_send() hands a locally created tensor over to the distributed backend.
y = torch.randn(20, 20).dist_send()
z = x + y
# z.get_node(), z.get_device() == 0, -1 (i.e. CPU)
cuda_x = x.cuda()
# cuda_x.get_node(), cuda_x.get_device() == 0, 0
# New tensors allocated inside this block are placed on node 1.
with torch.distributed.node(1):
    a = torch.distributed.FloatTensor(10, device=1)
    # a.get_node(), a.get_device() == 1, 1
    cuda_y = y.cuda()
    # cuda_y.get_node(), cuda_y.get_device() == 0, 0
    q = cuda_x + cuda_y
    # q.get_node(), q.get_device() == 0, 0
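
To make the synchronization behaviour concrete: copying a result back into the master process would block until the workers have finished computing it. A minimal sketch, continuing the example above (the name of the copy-back call is purely hypothetical and not part of the proposal):

# `dist_receive()` is a hypothetical name for copying a distributed tensor
# back to the master; whatever form it takes, that copy is a sync point.
z = x + y                    # dispatched to a worker, returns immediately
z_local = z.dist_receive()   # blocks until the workers have computed z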

How to launch the jobs

We'll provide a pytorch_exec utility that will spawn the process groups, in a similar fashion to mpiexec.

Decoupling data backends from other logic

You might have noticed that both init_process_group and init_master_worker accept a backend argument. We're aware that the best strategy for sending data may differ between setups, and picking a good one is crucial for limiting communication overhead. This is why we decided to introduce a DataChannel interface: users will be able to pick one of the provided implementations (initially MPI and raw TCP sockets, later RDMA, etc.) or add custom ones, so they can achieve the lowest overhead possible in their setup.
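
To give a feel for what a custom backend could look like, here is a minimal, hypothetical sketch of a DataChannel implementation; the method set and signatures are illustrative only and may not match what ends up in the thd branch:

# Hypothetical sketch of a custom DataChannel; the real interface may differ.
class MyDataChannel(object):
    def init(self):
        # Establish connections between all participating processes.
        raise NotImplementedError
    def send(self, tensor, dst_rank):
        # Send `tensor` to the process with rank `dst_rank`.
        raise NotImplementedError
    def recv(self, tensor, src_rank):
        # Receive into `tensor` from the process with rank `src_rank`.
        raise NotImplementedError
    def all_reduce(self, tensor, op):
        # Combine `tensor` across all processes and share the result.
        raise NotImplementedError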

Please let us know what you think! Thanks!
