don't dispatch aten.conj(scalar_tensor) back to python (pytorch#131482)
Fixes pytorch#105290

The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)`
(2) python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor`, calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit `__torch_dispatch__`, the `scalar_tensor` is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars.

My attempted fix in this PR is to update `TensorBase::conj()` to check if the current tensor is a scalar tensor (wrapped number), and if so, manually:
(1) convert the scalar tensor back into a scalar
(2) call scalar.conj() directly
(3) convert the result back into a wrapped tensor

This avoids having to go through python entirely in the tracing case (which is fine, because these scalar tensors are constants that we can const-prop during tracing anyway).
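In Python terms, the const-prop amounts to conjugating the underlying scalar value directly (a toy sketch; `conj_wrapped` is a hypothetical name, and a plain `complex` stands in for the C++ `Scalar` that backs a wrapped number):

```python
def conj_wrapped(value):
    # Stand-in for the new wrapped-number branch in TensorBase::conj():
    # instead of dispatching aten._conj on the scalar_tensor, conjugate
    # the scalar itself and (in the C++ code) re-wrap the result.
    if isinstance(value, complex):
        return value.conjugate()  # scalar.conj() in the C++ code
    return value  # real scalars are their own conjugate
```

Because this never leaves C++, tracing never sees an `aten._conj` call on a wrapped number at all.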

Notably, I did **not** add e.g. a new `aten._conj.Scalar` overload. This would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly. We would also need to muck with all `__torch_dispatch__` call sites to know to convert python scalars back into tensors directly.

Pull Request resolved: pytorch#131482
Approved by: https://github.com/zou3519, https://github.com/ezyang
ghstack dependencies: pytorch#131403
bdhirsh authored and pytorchmergebot committed Jul 26, 2024
1 parent 8bb9aa9 commit 5570a0d
Showing 2 changed files with 17 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test/dynamo/test_repros.py
@@ -5439,6 +5439,21 @@ def fn(x, d):
d["b"][0].x = 12
self.assertEqual(fn(inp, d), opt_fn(inp, d))

    def test_compile_complex_conj(self):
        def f(x):
            return torch.mul(x, 2j)

        x_ref = torch.randn(4, 2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_(True)

        out_ref = f(torch.view_as_complex(x_ref))
        out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
        self.assertEqual(out_ref, out_test)

        torch.view_as_real(out_ref).sum().backward()
        torch.view_as_real(out_test).sum().backward()
        self.assertEqual(x_ref.grad, x_test.grad)

def test_changing_stride(self):
cnt = torch._dynamo.testing.CompileCounter()

3 changes: 2 additions & 1 deletion torch/csrc/utils/python_arg_parser.cpp
@@ -95,7 +95,8 @@ bool should_allow_numbers_as_tensors(const std::string& name) {
"subtract", "subtract_", "subtract_out", // alias of sub
"true_divide", "true_divide_", "true_divide_out",
"to", "_to_copy", "copy_",
"floor_divide", "floor_divide_", "floor_divide_out"};
"floor_divide", "floor_divide_", "floor_divide_out",
"_conj"}; // _conj needed because mul.Tensor backward calls it
return allowed.find(name) != allowed.end();
}

