don't dispatch aten.conj(scalar_tensor) back to python (pytorch#131482)
Fixes pytorch#105290. The problem in the original flow is that:

(1) the user calls `torch.mul(complex_tensor, complex_scalar)`
(2) the python arg parser wraps the complex scalar in a `scalar_tensor`, and dispatches to `aten.mul.Tensor(self, scalar_other)`
(3) autograd sees `aten.mul.Tensor` and calls `scalar_other.conj()` [here](https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/FunctionsManual.cpp#L597)
(4) during proxy tensor tracing, this gets dispatched to `aten._conj(scalar_tensor)`
(5) when we hit `__torch_dispatch__`, the scalar_tensor is converted back into a plain python scalar
(6) we error during tracing, because in `FunctionalTensorMode.__torch_dispatch__` we try to redispatch on `aten._conj.default(plain_python_scalar)`, and this overload does not accept python scalars

(A minimal repro sketch of this flow appears below.)

My attempted fix in this PR is to update `TensorBase::conj()` to check whether the current tensor is a scalar tensor (wrapped number), and if so, manually:

(1) convert the scalar tensor back into a scalar
(2) call `scalar.conj()` directly
(3) convert the result back into a wrapped tensor

This avoids going through python entirely in the tracing case, which is fine because these scalar tensors are constants that we can const-prop during tracing anyway. (A sketch of this logic also appears below.)

Notably, I did **not** add e.g. a new `aten._conj.Scalar` overload. That would not actually fix the problem, since the bug is that we call `aten._conj.default(python_scalar)` directly, and we would also need to muck with all `__torch_dispatch__` call sites so they know to convert python scalars back into tensors.
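For reference, a minimal repro sketch of the failing flow. The entry point and shapes are illustrative assumptions, not copied from pytorch#105290, and the comments describe the pre-fix behavior:

```python
import torch

# Illustrative repro sketch (entry point and shapes are assumptions):
# before this PR, compiling through autograd ended up calling
# aten._conj.default(python_scalar) inside FunctionalTensorMode and erroring.
@torch.compile
def f(x):
    # complex tensor * complex python scalar: the python arg parser wraps
    # the scalar into a wrapped-number scalar_tensor before dispatch
    return torch.mul(x, 2 + 3j)

x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
# tracing the backward makes mul's derivative call scalar_other.conj()
# on the wrapped scalar tensor
f(x).sum().real.backward()
```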
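And a Python analogue of the `TensorBase::conj()` change, for illustration only; the real fix is in C++, and `is_wrapped_number` is a `TensorImpl` property that python code can't query or set directly, so it is passed in explicitly here:

```python
import torch

def conj_const_prop(t: torch.Tensor, is_wrapped_number: bool) -> torch.Tensor:
    # `is_wrapped_number` stands in for TensorImpl::is_wrapped_number(),
    # which is not exposed to python.
    if is_wrapped_number and t.dim() == 0:
        s = t.item()                                  # (1) scalar tensor -> python scalar
        if isinstance(s, complex):
            s = s.conjugate()                         # (2) conj the scalar directly
        # (3) re-wrap as a tensor (the C++ fix also re-marks it as a
        # wrapped number, which python can't do)
        return torch.scalar_tensor(s, dtype=t.dtype)
    return t.conj()  # normal path: dispatch aten._conj as before

# e.g. conj_const_prop(torch.scalar_tensor(2 + 3j), is_wrapped_number=True)
# returns tensor(2.-3.j) without ever dispatching aten._conj
```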
Pull Request resolved: pytorch#131482
Approved by: https://github.com/zou3519, https://github.com/ezyang
ghstack dependencies: pytorch#131403