Commit 26ac567

Revert "Fix CompiledDDP failure when the gradient is not contiguous (pytorch#138174)"
This reverts commit 0ecafda. Reverted pytorch#138174 on behalf of https://github.com/huydhn: "Sorry for reverting your PR, but I think it fails test_compute_comm_reordering in trunk for the ROCm and multi-GPU setups" (see the comment on pytorch#138174).
1 parent 98856f7 commit 26ac567

2 files changed: +8, -16 lines

test/distributed/test_c10d_functional_native.py

Lines changed: 8 additions & 15 deletions
@@ -597,11 +597,14 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         (
             FileCheck()
             .check("buf0 = empty")
-            # We always call .contiguous() on the input to all_reduce_,
-            # so input will not be a view anymore.
-            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
-            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
-            .check("return (buf0")
+            # Ensure the all_reduce_ input is a view
+            .check(
+                "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0"
+            )
+            .check(
+                "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0"
+            )
+            .check("return (reinterpret_tensor(buf0")
             .run(code)
         )

@@ -621,16 +624,6 @@ def func(arg: torch.Tensor) -> torch.Tensor:
         # clone induced by non contig input
         assert "torch.ops._c10d_functional.wait_tensor.default" in code
 
-        def func2(arg: torch.Tensor) -> torch.Tensor:
-            torch.ops._c10d_functional.all_reduce_(arg, "avg", "0")
-            return arg
-
-        compiled = torch.compile(func)
-
-        code = run_and_get_triton_code(compiled, arg)
-        # clone induced by non contig input
-        assert "torch.ops._c10d_functional.wait_tensor.default" in code
-
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @fresh_inductor_cache()
     def test_inductor_reuse_buffer_after_inplace_collective(self):
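
For context, the restored test above expects Inductor to leave the collective's input as a view (reinterpret_tensor) rather than forcing a copy. Below is a minimal single-process sketch of the op sequence the test compiles. This is my reconstruction, not part of the commit: it assumes a "gloo" default group (which PyTorch registers under the group name "0", the string the test passes) and the CPU Inductor backend, whereas the test itself requires a GPU.

    import os
    import torch
    import torch.distributed as dist

    # Single-process "gloo" group; the default group is registered under the
    # group name "0", the group_name string the test passes to all_reduce_.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    def func(arg: torch.Tensor) -> torch.Tensor:
        # In-place functional collective followed by a wait, as in the test.
        torch.ops._c10d_functional.all_reduce_(arg, "avg", "0")
        return torch.ops._c10d_functional.wait_tensor(arg)

    print(func(torch.ones(4, 4)))  # eager: with world_size=1, "avg" is a no-op

    # torch.compile routes the same op through the _all_reduce_ lowering
    # shown in torch/_inductor/lowering.py below.
    compiled = torch.compile(func)
    print(compiled(torch.ones(4, 4)))

    dist.destroy_process_group()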

torch/_inductor/lowering.py

Lines changed: 0 additions & 1 deletion
@@ -6470,7 +6470,6 @@ def _all_reduce(inp, reduce_op, group_name):
 
 @register_lowering(_c10d_functional.all_reduce_)
 def _all_reduce_(inp, reduce_op, group_name):
-    inp = ir.ExternKernel.require_contiguous(inp)
     ir._CollectiveKernel.create_inplace(
         _c10d_functional.all_reduce_.default, inp, reduce_op, group_name
     )
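
The removed line is the substance of the revert: ir.ExternKernel.require_contiguous had forced the input of the in-place all_reduce_ lowering to be contiguous, which is why the old test expected plain buf0 rather than a reinterpret_tensor view in the generated code. A small eager-mode sketch of the contiguity property in question (illustrative only, not part of the commit; the tensor names are mine):

    import torch

    x = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    grad_view = x.t()  # transposed view: same storage, swapped strides
    print(grad_view.is_contiguous())               # False: the non-contiguous-gradient case
    print(grad_view.contiguous().is_contiguous())  # True: what require_contiguous forced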
