
Commit d950f45

Revert "[Functional Collectives] Migrate DeviceMesh::all_reduce to use functional all_reduce. (pytorch#95009)"
This reverts commit 0765dbc. Reverted pytorch#95009 on behalf of https://github.com/jeanschmidt because this PR is causing internal breakages. See https://fburl.com/diff/me41urq8
1 parent 1cf11c1 · commit d950f45

6 files changed (+26, -21 lines changed)
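For background, a hedged sketch of the two all_reduce styles involved in this revert: the functional collective that pytorch#95009 routed DeviceMesh.all_reduce through is out-of-place and hands back a new tensor, while the c10d collective restored here mutates its input and returns at most a Work handle. The device, mesh shape, and initialized process group below are illustrative assumptions, not part of the diff.

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol
    from torch.distributed._tensor import DeviceMesh

    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
    x = torch.ones(3, 3, device="cuda") * dist.get_rank()

    # functional form (what the reverted PR used): out-of-place, returns a fresh
    # tensor; mirrors the group=(mesh, dim) call shape seen in the diff below
    reduced = funcol.all_reduce(x, reduceOp="SUM", group=(mesh, 0))

    # c10d form (what this commit restores underneath DeviceMesh.all_reduce):
    # reduces x in place and returns None, or a Work handle when async_op=True
    dist.all_reduce(x, op=dist.ReduceOp.SUM)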

test/distributed/_spmd/test_tracing.py

Lines changed: 3 additions & 2 deletions

@@ -47,9 +47,10 @@ def _test_tracing_all_reduce_nd(self, mesh_tensor):
         ]
 
         def fn(tensor: torch.Tensor):
-            tensor = mesh.all_reduce(tensor, mesh_dim=dim)
+            tensor_to_reduce = CommTensor(tensor.clone())
+            mesh.all_reduce(tensor_to_reduce, mesh_dim=dim)
             # multiply with 1 to trigger wait on read during tracing.
-            return tensor * 1
+            return tensor_to_reduce * 1
 
         # use a local_tensor + 1 for tracing to make sure that we are not
         # simply replaying recorded tensor value
test/distributed/_tensor/test_device_mesh.py

Lines changed: 6 additions & 5 deletions

@@ -13,7 +13,6 @@
     is_initialized,
     new_group,
     ProcessGroup,
-    get_process_group_ranks
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (

@@ -240,8 +239,7 @@ def world_size(self):
     def test_all_reduce_1d(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
         local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
-        # We have to clone the result tensor because assertEqual fails to compare AsyncTensor with plain tensor.
-        local_tensor = mesh.all_reduce(local_tensor, mesh_dim=0).clone()
+        mesh.all_reduce(local_tensor, mesh_dim=0)
         res_num = ((0 + self.world_size - 1) * self.world_size) / 2
         self.assertEqual(local_tensor, torch.ones(3, 3) * res_num)
 
@@ -481,9 +479,12 @@ def test_all_reduce_nd(self):
         # check all dim groups
         dim_to_subgroups = mesh.get_dim_groups()
         for dim, dim_group in enumerate(dim_to_subgroups):
-            global_ranks = get_process_group_ranks(dim_group)
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
             cloned_local_tensor = local_tensor.clone()
-            cloned_local_tensor = mesh.all_reduce(cloned_local_tensor, mesh_dim=dim).clone()
+            mesh.all_reduce(cloned_local_tensor, mesh_dim=dim)
             res_num = sum(global_ranks)
             self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)
 
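For reference, a hedged sketch of the rank bookkeeping the nd test now does by hand: the global ranks of each mesh-dimension group are recovered with get_world_size plus get_global_rank instead of get_process_group_ranks, and the expected all_reduce result is simply their sum. The 2-D mesh shape and device are illustrative; an initialized process group is assumed.

    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import get_global_rank, get_world_size
    from torch.distributed._tensor import DeviceMesh

    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).reshape(2, -1))
    for dim, dim_group in enumerate(mesh.get_dim_groups()):
        dim_group_size = get_world_size(dim_group)
        global_ranks = [get_global_rank(dim_group, i) for i in range(dim_group_size)]
        # a tensor seeded with the local rank and all-reduced over this dimension
        # ends up holding sum(global_ranks) on every member of the group
        print(dim, global_ranks, sum(global_ranks))

In the 1-D test above, the same sum collapses to the arithmetic-series form ((0 + world_size - 1) * world_size) / 2.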

torch/distributed/_functional_collectives.py

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ def _all_reduce(self, reduceOp, tag, ranks, group_size):
     group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
     assert group is not None
 
-    inplace_tensor = self.clone(memory_format=torch.contiguous_format)
+    inplace_tensor = self.clone()
     work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
     _register_tensor_work(inplace_tensor, work)
 
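The only change here is dropping memory_format=torch.contiguous_format from the clone. As a small, self-contained illustration (not taken from this file) of what that argument controls: clone() defaults to torch.preserve_format, so a non-contiguous layout survives the copy unless a contiguous clone is requested explicitly.

    import torch

    x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
    print(x.is_contiguous())          # False: channels_last layout
    print(x.clone().is_contiguous())  # False: clone preserves the layout
    print(x.clone(memory_format=torch.contiguous_format).is_contiguous())  # True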

torch/distributed/_spmd/distribute.py

Lines changed: 1 addition & 1 deletion

@@ -249,7 +249,7 @@ def _convert_output(
 
     traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)
 
-    wait = [n for n in traced_dispatch.graph.nodes if n.name == "wait_comm" or n.name == "wait_tensor"]
+    wait = [n for n in traced_dispatch.graph.nodes if n.name == "wait_comm"]
     add = [n for n in traced_dispatch.graph.nodes if n.name == "add"]
     assert len(wait) == 1 and len(add) == 1
 
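This change just narrows an FX-graph scan back to wait_comm nodes only. As a self-contained illustration of the filtering idiom (a toy function, not the SPMD machinery), filtering torch.fx graph nodes by name looks like this:

    import torch
    import torch.fx

    def add_fn(x):
        return torch.add(x, x)

    gm = torch.fx.symbolic_trace(add_fn)
    # collect the call_function node that torch.fx named "add"
    add_nodes = [n for n in gm.graph.nodes if n.name == "add"]
    assert len(add_nodes) == 1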

torch/distributed/_tensor/device_mesh.py

Lines changed: 9 additions & 10 deletions

@@ -7,6 +7,7 @@
 from torch.distributed.distributed_c10d import (
     _get_default_group,
     all_gather,
+    all_reduce,
     all_to_all,
     broadcast,
     get_global_rank,

@@ -22,9 +23,6 @@
     scatter,
     Work,
 )
-import torch.distributed.distributed_c10d as c10d
-
-import torch.distributed._functional_collectives as funcol
 
 _global_device_mesh: Optional["DeviceMesh"] = None
 
@@ -420,7 +418,8 @@ def all_reduce(
         tensor: torch.Tensor,
         op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
         mesh_dim: int = 0,
-    ) -> torch.Tensor:
+        async_op: bool = False,
+    ) -> Optional[Work]:
         """
         all_reduce the tensor on each rank on a device mesh dimension, and
         return an output tensor on each rank after all_reduce.

@@ -433,10 +432,10 @@
                 to reduce on.
 
         Returns:
-            A :class:`torch.Tensor` object
+            A :class:`Work` object
         """
-        op_name: str = op.name  # type: ignore[attr-defined]
-        return funcol.all_reduce(tensor, reduceOp=op_name, group=(self, mesh_dim,))
+        dim_group = self._dim_groups[mesh_dim]
+        return all_reduce(tensor, op=op, group=dim_group, async_op=async_op)
 
     def reduce_scatter(
         self,

@@ -494,9 +493,9 @@ def reduce_scatter(
         flat_tensor = torch.cat(flattened_list).clone(
             memory_format=torch.contiguous_format
         )
-        dim_group = self._dim_groups[mesh_dim]
-        fut = c10d.all_reduce(flat_tensor, op=op, group=dim_group, async_op=async_op)
-
+        fut = self.all_reduce(
+            flat_tensor, op=op, mesh_dim=mesh_dim, async_op=async_op
+        )
         # scatter the tensor
         output_offset = offset_list[my_coordinate]
         output.copy_(
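A hedged usage sketch of the signature this revert restores: DeviceMesh.all_reduce once again mutates the given tensor along one mesh dimension and returns None, or a Work handle when async_op=True. The device, mesh shape, and initialized process group are illustrative assumptions.

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh

    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
    t = torch.ones(3, 3, device="cuda") * dist.get_rank()

    # synchronous: reduces t in place across mesh dim 0, returns None
    mesh.all_reduce(t, mesh_dim=0)

    # asynchronous: returns a Work handle that must be waited on before reading t
    work = mesh.all_reduce(t, mesh_dim=0, async_op=True)
    if work is not None:
        work.wait()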

torch/distributed/_tensor/placement_types.py

Lines changed: 6 additions & 2 deletions

@@ -250,9 +250,13 @@ def __init__(self, reduce_op: c10d.ReduceOp = c10d.ReduceOp.SUM):  # type: ignor
     def _to_replicate(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
-        return mesh.all_reduce(
-            tensor, self.reduce_op, mesh_dim=mesh_dim  # type: ignore[call-arg]
+        # out-of-place all_reduce to replicate, since the current partial DTensor
+        # might get used by other ops as well, so we can't inplace modify it
+        cloned_local = CommTensor(tensor.clone(memory_format=torch.contiguous_format))
+        mesh.all_reduce(
+            cloned_local, self.reduce_op, mesh_dim=mesh_dim  # type: ignore[call-arg]
         )
+        return cloned_local
 
     def _to_shard(
         self,
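The restored _to_replicate body spells out the contract its comment describes: the partial local tensor may still feed other ops, so the reduction runs in place on a contiguous clone (wrapped in CommTensor so traced runs record the collective) and the clone is what gets returned. A minimal sketch of the pattern; the helper name partial_to_replicate and the CommTensor import path are assumptions:

    import torch
    from torch.distributed._spmd.comm_tensor import CommTensor  # assumed path
    from torch.distributed._tensor import DeviceMesh

    def partial_to_replicate(tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int) -> torch.Tensor:
        # reduce a contiguous clone in place; the original partial tensor stays intact
        cloned_local = CommTensor(tensor.clone(memory_format=torch.contiguous_format))
        mesh.all_reduce(cloned_local, mesh_dim=mesh_dim)
        return cloned_local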
