Support torch.distributed.irecv(src=None, ...) (pytorch#47137)
Summary:
Calling torch.distributed.irecv(src=None) fails with "The global rank None is not part of the group". This change dispatches to recv_anysource when src is None. Tested locally with the MPI backend.
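
For illustration, a minimal sketch of the call this change enables. The two-process setup and the launch command are assumptions for the sketch (an MPI-enabled build started with something like mpirun -n 2 python script.py), not part of the commit:

    import torch
    import torch.distributed as dist

    # Assumed launch: mpirun -n 2 python script.py (MPI backend built in).
    dist.init_process_group(backend="mpi")

    if dist.get_rank() == 0:
        buf = torch.zeros(10)
        # Omitting src previously raised "The global rank None is not
        # part of the group"; it now receives from any source.
        work = dist.irecv(buf)  # src defaults to None
        work.wait()
    else:
        dist.send(torch.full((10,), float(dist.get_rank())), dst=0)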

Pull Request resolved: pytorch#47137

Reviewed By: heitorschueroff

Differential Revision: D25292656

fbshipit-source-id: beb018ba0b676924aeaabeb4a4d6acf96e4a1926
froody authored and facebook-github-bot committed Dec 4, 2020
1 parent e1f9542 commit 4eb4db7
Showing 2 changed files with 31 additions and 14 deletions.
torch/distributed/distributed_c10d.py (11 additions, 6 deletions)
@@ -702,15 +702,16 @@ def isend(tensor,
 
 
 def irecv(tensor,
-          src,
+          src=None,
           group=group.WORLD,
           tag=0):
     """
     Receives a tensor asynchronously.
 
     Arguments:
         tensor (Tensor): Tensor to fill with received data.
-        src (int): Source rank.
+        src (int, optional): Source rank. Will receive from any
+            process if unspecified.
         group (ProcessGroup, optional): The process group to work on
         tag (int, optional): Tag to match recv with remote send
@@ -724,11 +725,15 @@ def irecv(tensor,
         return
 
     if group == GroupMember.WORLD:
-        default_pg = _check_default_pg()
-        return default_pg.recv([tensor], src, tag)
+        pg = _check_default_pg()
     else:
-        group_src_rank = _get_group_rank(group, src)
-        return group.recv([tensor], group_src_rank, tag)
+        pg = group
+
+    if src is None:
+        return pg.recv_anysource([tensor], tag)
+    else:
+        group_src_rank = _get_group_rank(pg, src)
+        return pg.recv([tensor], group_src_rank, tag)
 
 
 def send(tensor,
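With the rewritten body above, both receive flavors share one process-group handle: an explicit src is mapped to a group rank and passed to pg.recv, while src=None routes to pg.recv_anysource. A hedged sketch of the two caller-side paths, continuing the setup from the sketch above; note that _source_rank() on the returned work object is an internal accessor (the updated test below uses it), not public API:

    # Explicit source: src is resolved to a group rank, then pg.recv runs.
    work = dist.irecv(buf, src=1)
    work.wait()

    # Any source: src=None dispatches to pg.recv_anysource.
    work = dist.irecv(buf)
    work.wait()
    sender = work._source_rank()  # internal accessor; see the test below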
torch/testing/_internal/distributed/distributed_test.py (20 additions, 8 deletions)
@@ -850,26 +850,38 @@ def test_send_recv_any_source(self):
         rank = dist.get_rank()
         tensor = _build_tensor(10, value=rank)
         recv_ranks = set()
+        irecv_ranks = set()
 
         for dst in range(0, dist.get_world_size()):
             if dst == rank:
                 # Recv mode
                 for dst in range(0, dist.get_world_size()):
                     if dst == rank:
                         continue
-                    output_tensor = _build_tensor(10, value=-1)
-                    sender = dist.recv(output_tensor)
+                    for recv in ["recv", "irecv"]:
+                        output_tensor = _build_tensor(10, value=-1)
 
-                    # Assert the scalar value "sender" that should be
-                    # equal to the rank of the sender is equal to all
-                    # values in the received tensor.
-                    self.assertTrue(output_tensor.eq(sender).all())
-                    recv_ranks.add(sender)
+                        if recv == "recv":
+                            sender = dist.recv(output_tensor)
+                            recv_ranks.add(sender)
+                        elif recv == "irecv":
+                            work = dist.irecv(output_tensor)
+                            work.wait()
+                            sender = work._source_rank()
+                            irecv_ranks.add(sender)
+
+                        # Assert the scalar value "sender" that should be
+                        # equal to the rank of the sender is equal to all
+                        # values in the received tensor.
+                        self.assertTrue(output_tensor.eq(sender).all())
             else:
                 # Send mode
-                dist.send(tensor, dst)
+                dist.send(tensor, dst)  # recv
+                dist.send(tensor, dst)  # irecv
 
         self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
+        self.assertEqual(len(irecv_ranks), dist.get_world_size() - 1)
         self._barrier()
 
         # SEND RECV WITH TAG
