Add _allgather_base & _reduce_scatter_base to dist backend #3919

Open · wants to merge 4 commits into master
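For orientation, here is a minimal sketch, distilled from the tests in this PR rather than copied from the diff, of how the two new private collectives are driven through torch.distributed once the 'xla' process group is initialized (the function name _example_step is hypothetical):

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def _example_step():
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  rank = xm.get_ordinal()
  dist.init_process_group('xla', world_size=world_size, rank=rank)

  # _all_gather_base: every rank contributes a (2, 3) tensor and reads back
  # a single contiguous (2 * world_size, 3) result.
  inp = (torch.ones((2, 3)) * rank).to(device)
  out = torch.zeros((2 * world_size, 3)).to(device)
  dist._all_gather_base(out, inp)

  # _reduce_scatter_base: a (32, 3) input is summed across ranks and each
  # rank keeps only its (32 // world_size, 3) slice.
  inp = torch.ones((32, 3)).to(device)
  out = torch.zeros((32 // world_size, 3)).to(device)
  dist._reduce_scatter_base(out, inp)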
35 changes: 32 additions & 3 deletions test/test_torch_distributed_all_gather_xla_backend.py
@@ -8,14 +8,12 @@
import torch.distributed as dist


def _mp_fn(index):
def _test_allgather():
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    input = torch.ones((2, 3)) * rank
    outputs = [torch.zeros_like(input)] * world_size
    xinput = input.to(device)
@@ -33,5 +31,36 @@ def _mp_fn(index):
        file=sys.stderr)


def _test__allgather_base():
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    input = torch.ones((2, 3)) * rank
    output = torch.zeros((2 * world_size, 3))
    xinput = input.to(device)
    xoutput = output.to(device)
    dist._all_gather_base(xoutput, xinput)
    # Split into world_size chunks of 2 rows each (one chunk per rank).
    xoutputs = torch.split(xoutput, 2)
    for i, o in enumerate(xoutputs):
      expected = torch.ones((2, 3)) * i
      assert torch.all(o.cpu() == expected), f'{o} != {expected}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    dist.init_process_group('xla', world_size=world_size, rank=rank)
    _test_allgather()
    _test__allgather_base()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
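For readers comparing the two code paths in this test: unlike dist.all_gather, which fills a Python list with one tensor per rank, _all_gather_base writes all ranks' contributions into one pre-allocated tensor whose leading dimension is world_size times the input's. A sketch of that assumed relationship (the helper below is illustrative and not part of the PR):

import torch
import torch.distributed as dist


def _check_allgather_equivalence(inp, world_size):
  # Flat variant: one contiguous output, ranks laid out along dim 0.
  flat = torch.zeros((inp.shape[0] * world_size,) + inp.shape[1:],
                     device=inp.device)
  dist._all_gather_base(flat, inp)

  # List variant: one output tensor per rank.
  listed = [torch.zeros_like(inp) for _ in range(world_size)]
  dist.all_gather(listed, inp)

  # Splitting the flat output into per-rank chunks should reproduce the list.
  for chunk, gathered in zip(torch.split(flat, inp.shape[0]), listed):
    assert torch.all(chunk == gathered)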
34 changes: 31 additions & 3 deletions test/test_torch_distributed_reduce_scatter_xla_backend.py
@@ -8,14 +8,12 @@
import torch.distributed as dist


def _mp_fn(index):
def _test_reduce_scatter():
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    dist.init_process_group('xla', world_size=world_size, rank=rank)

    input_size = (32, 3)
    inputs = torch.ones(input_size).split(input_size[0] // world_size)
    output = torch.zeros_like(inputs[0])
@@ -30,5 +28,35 @@ def _mp_fn(index):
        file=sys.stderr)


def _test__reduce_scatter_base():
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    input_size = (32, 3)
    input = torch.ones(input_size)
    output = torch.zeros((input_size[0] // world_size, input_size[1]))
    xinput = input.to(device)
    xoutput = output.to(device)
    dist._reduce_scatter_base(xoutput, xinput)
    expected = torch.ones_like(output) * world_size
    assert torch.all(xoutput.cpu() == expected), f'{xoutput} != {expected}'
  else:
    print(
        'Default device {} is not a TPU or GPU device'.format(device),
        file=sys.stderr)


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    dist.init_process_group('xla', world_size=world_size, rank=rank)
    _test_reduce_scatter()
    _test__reduce_scatter_base()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
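For reference, the semantics this test relies on: _reduce_scatter_base takes a single input of shape (N, ...), reduces it element-wise across all ranks (sum by default), and hands each rank the (N // world_size, ...) slice corresponding to its rank. A small sketch under that assumption (the helper name is hypothetical):

import torch
import torch.distributed as dist


def _reduce_scatter_flat(full, world_size):
  # `full` has shape (N, ...) with N divisible by world_size; the returned
  # shard holds the cross-rank sum for this rank's contiguous slice of dim 0.
  shard = torch.zeros((full.shape[0] // world_size,) + full.shape[1:],
                      device=full.device)
  dist._reduce_scatter_base(shard, full)
  return shard

With full = torch.ones((32, 3)) on every rank, every element of the returned (32 // world_size, 3) shard equals world_size, which is exactly what the assertion in the test checks.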
19 changes: 19 additions & 0 deletions torch_xla/distributed/xla_backend.py
@@ -78,6 +78,10 @@ def allgather(self, output_tensors_list, input_tensors):

    return _ret_work([t for sublist in output_tensors_list for t in sublist])

  def _allgather_base(self, output, input):
    xm.all_gather(input, output=output, groups=self._mesh)
    return _ret_work([output])

  # Call site:
  # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129
  def broadcast(self, tensors, opts):
@@ -117,6 +121,21 @@ def reduce_scatter(self, output_tensors, input_tensors_list, opts):

    return _ret_work(output_tensors)

  def _reduce_scatter_base(self, output, input, opts):
    reduce_type = self._get_reduce_type(opts.reduceOp)
    groups = self._mesh
    shard_count = len(groups[0]) if groups else self.size()
    xm.reduce_scatter(
        reduce_type,
        input,
        scatter_dim=0,
        shard_count=shard_count,
        scale=1,
        groups=groups,
        output=output,
        pin_layout=False)
    return _ret_work(output)

  # Call site:
  # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683
  def barrier(self, opts):
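These base variants mainly serve flat-tensor sharding schemes (ZeRO/FSDP-style training), where a single fused buffer is gathered or reduce-scattered instead of a list of per-parameter tensors. The sketch below shows that usage pattern under that assumption, with hypothetical names; it is not taken from this PR:

import torch
import torch.distributed as dist


def _sharded_update(param_shard, flat_grad, world_size, lr=0.1):
  # Reduce-scatter the full flat gradient: each rank keeps only the summed
  # gradient for its own shard.
  grad_shard = torch.empty_like(param_shard)
  dist._reduce_scatter_base(grad_shard, flat_grad)

  # Update just the local shard (averaging the summed gradient).
  param_shard -= lr * grad_shard / world_size

  # Rebuild the full flat parameter from all updated shards for the next
  # forward pass.
  full_param = torch.empty(param_shard.numel() * world_size,
                           dtype=param_shard.dtype,
                           device=param_shard.device)
  dist._all_gather_base(full_param, param_shard)
  return full_param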