From 1a0db29932655ebc545cb175eebef116c1247e61 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Fri, 2 Aug 2024 11:59:39 -0700
Subject: [PATCH] move torch._functionalize APIs to pybind. add one for
 marking storage mutations (#132337)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132337
Approved by: https://github.com/albanD, https://github.com/justinchuby
ghstack dependencies: #132243
---
 aten/src/ATen/FunctionalTensorWrapper.h       |   4 +
 test/onnx/test_fx_to_onnx_with_onnxruntime.py |   2 +-
 .../python_torch_functions_manual.cpp         | 725 +++++-------------
 3 files changed, 177 insertions(+), 554 deletions(-)

diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index c8acfd3941f37e..7bea72a6cfb53a 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -161,6 +161,10 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
     return was_storage_changed_;
   }
 
+  void set_storage_changed() {
+    was_storage_changed_ = true;
+  }
+
   c10::SymInt get_storage_size(bool before) {
     return functional_storage_impl()->get_storage_size(before);
   }
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 4ffc725a07168d..68c50bc690f6a8 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -539,7 +539,7 @@ def forward(self, x):
         )
 
     @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
-        error_message="at::functionalization::impl::isFunctionalTensor(self_) INTERNAL ASSERT FAILED"
+        error_message="at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED"
     )
     def test_expand_as_fill_separate_tensor(self):
         class Model(torch.nn.Module):
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index c36cf275a6b2a1..9bc66c693cfd79 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -360,463 +360,6 @@ static PyObject* THPVariable_numel(
     PyObject* args,
     PyObject* kwargs);
 
-static PyObject* THPVariable__to_functional_tensor(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_to_functional_tensor(Tensor t)"},
-      /*traceable=*/true);
-
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  auto wrapped = at::functionalization::impl::to_functional_tensor(self_);
-  return wrap(std::move(wrapped));
-  END_HANDLE_TH_ERRORS
-}
-
-// Given source and dest tensors,
-// Sets **some** (but not all) autograd metadata on dest, according to source:
-// - requires_grad
-// - grad_fn
-//   (If src has a grad_fn, we install an error grad_fn on dest to avoid
-//   difficult bugs.
-//   The main purpose is to ensure that dst.is_leaf == src.is_leaf)
-static PyObject* THPVariable__mirror_autograd_meta_to(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_mirror_autograd_meta_to(Tensor source, Tensor dest)"},
-      /*traceable=*/true);
-
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto src_ = r.tensor(0);
-  auto dst_ = r.tensor(1);
-  // Here, we unsafely set the grad function on the wrapper to be the same as
-  // the inner. We expect this grad_fn to NEVER be used. It's needed so that
-  // .is_leaf metadata is accurate on the wrapper
-  auto inner_autograd_meta = impl::get_autograd_meta(src_);
-  if (inner_autograd_meta) {
-    dst_.set_requires_grad(src_.requires_grad());
-    if (dst_.requires_grad()) {
-      auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
-          new torch::autograd::Error(
-              "Cannot backprop through mirrored meta, file a bug in PyTorch"),
-          torch::autograd::deleteNode);
-      torch::autograd::set_history(dst_, new_grad_fn);
-    }
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__from_functional_tensor(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_from_functional_tensor(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  auto unwrapped = at::functionalization::impl::from_functional_tensor(self_);
-  return wrap(std::move(unwrapped));
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__freeze_functional_tensor(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_freeze_functional_tensor(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  at::functionalization::impl::freeze_functional_tensor(self_);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__is_functional_tensor(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_is_functional_tensor(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  if (at::functionalization::impl::isFunctionalTensor(self_)) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_was_storage_changed(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_was_storage_changed(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (wrapper->was_storage_changed()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_get_storage_size(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_get_storage_size(Tensor t, bool before)"},
-      /*traceable=*/true);
-
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  auto before = r.toBool(1);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  auto size = wrapper->get_storage_size(/*before=*/before);
-  return toPyObject(size);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_has_data_mutation(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_has_data_mutation(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (wrapper->has_data_mutation()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_has_metadata_mutation(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_has_metadata_mutation(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (wrapper->has_metadata_mutation()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__enable_functionalization(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_enable_functionalization(*, bool reapply_views=False)"},
-      /*traceable=*/true);
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  const auto reapply_views = r.toBool(0);
-
-  if (c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Functionalize)) {
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "multiple layers of mode-style functionalization nesting is not"
-        " currently supported, outside of the functionalize() transform");
-  }
-  c10::impl::tls_set_dispatch_key_included(
-      at::DispatchKey::Functionalize, true);
-  if (reapply_views) {
-    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true);
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_enable_reapply_views(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_enable_reapply_views(bool reapply_views=False)"},
-      /*traceable=*/true);
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  const auto reapply_views = r.toBool(0);
-  auto old = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
-  if (reapply_views) {
-    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(true);
-  } else {
-    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
-  }
-  if (old) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_is_multi_output_view(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_is_multi_output_view(Tensor t)"},
-      /*traceable=*/true);
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto t = r.tensor(0);
-  TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(t));
-  auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
-  if (t_impl->is_multi_output_view()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__disable_functionalization(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  c10::impl::tls_set_dispatch_key_included(
-      at::DispatchKey::Functionalize, false);
-  at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_replace(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_replace(Tensor t, Tensor o)"}, /*traceable=*/true);
-
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  auto other = r.tensor(1);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  TORCH_INTERNAL_ASSERT(
-      !at::functionalization::impl::isFunctionalTensor(other));
-  at::functionalization::impl::replace_(self_, other);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_commit_update(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_commit_update(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  at::functionalization::impl::commit_update(self_);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_sync(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_sync(Tensor t)"}, /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  at::functionalization::impl::sync(self_);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_is_symbolic(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_is_symbolic(Tensor tensor)"},
-      /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto tensor = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(
-      at::functionalization::impl::isFunctionalTensor(tensor));
-  auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-  if (impl->is_symbolic()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_apply_view_metas(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_apply_view_metas(Tensor tensor, Tensor base)"},
-      /*traceable=*/true);
-
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto tensor = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(
-      at::functionalization::impl::isFunctionalTensor(tensor));
-  auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
-  return wrap(impl->apply_view_metas(r.tensor(1)));
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_mark_mutation_hidden_from_autograd(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_mark_mutation_hidden_from_autograd(Tensor t)"},
-      /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  at::functionalization::impl::mark_mutation_hidden_from_autograd(self_);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject*
-THPVariable__functionalize_are_all_mutations_hidden_from_autograd(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_are_all_mutations_hidden_from_autograd(Tensor t)"},
-      /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  if (at::functionalization::impl::are_all_mutations_hidden_from_autograd(
-          self_)) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject* THPVariable__functionalize_was_inductor_storage_resized(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_was_inductor_storage_resized(Tensor t)"},
-      /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  auto functional_impl =
-      at::functionalization::impl::unsafeGetFunctionalWrapper(self_);
-  if (functional_impl->was_inductor_storage_resized()) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject*
-THPVariable__functionalize_are_all_mutations_under_no_grad_or_inference_mode(
-    PyObject* self,
-    PyObject* args,
-    PyObject* kwargs) {
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser(
-      {"_functionalize_are_all_mutations_under_no_grad_or_inference_mode(Tensor t)"},
-      /*traceable=*/true);
-
-  ParsedArgs<1> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  auto self_ = r.tensor(0);
-  TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self_));
-  if (at::functionalization::impl::
-          are_all_mutations_under_no_grad_or_inference_mode(self_)) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
 
 // XXX: ops that are bound here are not exposed to the C++ api nor the JIT.
 // Any new ops added here should be accompanied with a comment why they are not
 // being registered through native_functions.yaml, and be tagged cpp / JIT
@@ -835,102 +378,6 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable_frombuffer),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
-    {"_is_functional_tensor",
-     castPyCFunctionWithKeywords(THPVariable__is_functional_tensor),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_to_functional_tensor",
-     castPyCFunctionWithKeywords(THPVariable__to_functional_tensor),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_mirror_autograd_meta_to",
-     castPyCFunctionWithKeywords(THPVariable__mirror_autograd_meta_to),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_from_functional_tensor",
-     castPyCFunctionWithKeywords(THPVariable__from_functional_tensor),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_freeze_functional_tensor",
-     castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_replace",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_replace),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_commit_update",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_commit_update),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_sync",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_is_symbolic",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_is_symbolic),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_apply_view_metas",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_apply_view_metas),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_enable_functionalization",
-     castPyCFunctionWithKeywords(THPVariable__enable_functionalization),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_disable_functionalization",
-     castPyCFunctionWithKeywords(THPVariable__disable_functionalization),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_has_metadata_mutation",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_has_metadata_mutation),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_mark_mutation_hidden_from_autograd",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_mark_mutation_hidden_from_autograd),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_are_all_mutations_hidden_from_autograd",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_are_all_mutations_hidden_from_autograd),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_was_inductor_storage_resized",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_was_inductor_storage_resized),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_are_all_mutations_under_no_grad_or_inference_mode),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_is_multi_output_view",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_is_multi_output_view),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_has_data_mutation",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_has_data_mutation),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_was_storage_changed",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_was_storage_changed),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_get_storage_size",
-     castPyCFunctionWithKeywords(THPVariable__functionalize_get_storage_size),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
-    {"_functionalize_enable_reapply_views",
-     castPyCFunctionWithKeywords(
-         THPVariable__functionalize_enable_reapply_views),
-     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
-     nullptr},
     {"nonzero",
      castPyCFunctionWithKeywords(THPVariable_nonzero),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
@@ -1152,6 +599,178 @@ void initTorchFunctions(PyObject* module) {
           module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
     throw python_error();
   }
+
+  // pybind registrations to torch module
+  // TODO: move these from torch.* to torch._C.*
+  auto py_module = py::module::import("torch");
+
+  py_module.def(
+      "_functionalize_are_all_mutations_under_no_grad_or_inference_mode",
+      [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        return at::functionalization::impl::
+            are_all_mutations_under_no_grad_or_inference_mode(t);
+      });
+  py_module.def(
+      "_functionalize_was_inductor_storage_resized", [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+        return impl->was_inductor_storage_resized();
+      });
+  py_module.def(
+      "_functionalize_are_all_mutations_hidden_from_autograd",
+      [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        return at::functionalization::impl::
+            are_all_mutations_hidden_from_autograd(t);
+      });
+  py_module.def(
+      "_functionalize_mark_mutation_hidden_from_autograd",
+      [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
+      });
+  py_module.def(
+      "_functionalize_apply_view_metas",
+      [](const at::Tensor& tensor, const at::Tensor& base) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(tensor));
+        auto impl =
+            at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
+        return impl->apply_view_metas(base);
+      });
+  py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    return impl->is_symbolic();
+  });
+  py_module.def("_functionalize_sync", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    at::functionalization::impl::sync(t);
+  });
+  py_module.def("_functionalize_commit_update", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    at::functionalization::impl::commit_update(t);
+  });
+  py_module.def(
+      "_functionalize_replace", [](const at::Tensor& t, const at::Tensor& o) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        TORCH_INTERNAL_ASSERT(
+            !at::functionalization::impl::isFunctionalTensor(o));
+        at::functionalization::impl::replace_(t, o);
+      });
+  py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    return t_impl->is_multi_output_view();
+  });
+  py_module.def(
+      "_functionalize_enable_reapply_views",
+      [](bool reapply_views = false) {
+        auto old =
+            at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
+        at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
+            reapply_views);
+        return old;
+      },
+      py::arg("reapply_views") = false);
+  py_module.def(
+      "_functionalize_has_metadata_mutation", [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        auto t_impl =
+            at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+        return t_impl->has_metadata_mutation();
+      });
+  py_module.def("_functionalize_has_data_mutation", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    return t_impl->has_data_mutation();
+  });
+  py_module.def(
+      "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        auto wrapper =
+            at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+        auto size = wrapper->get_storage_size(/*before=*/before);
+        return size;
+      });
+  py_module.def("_functionalize_set_storage_changed", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    wrapper->set_storage_changed();
+  });
+  py_module.def("_functionalize_was_storage_changed", [](const at::Tensor& t) {
+    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
+    auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t);
+    return wrapper->was_storage_changed();
+  });
+  py_module.def(
+      "_functionalize_mark_mutation_hidden_from_autograd",
+      [](const at::Tensor& t) {
+        TORCH_INTERNAL_ASSERT(
+            at::functionalization::impl::isFunctionalTensor(t));
+        at::functionalization::impl::mark_mutation_hidden_from_autograd(t);
+      });
+  py_module.def("_is_functional_tensor", [](const at::Tensor& t) {
+    return at::functionalization::impl::isFunctionalTensor(t);
+  });
+  py_module.def("_to_functional_tensor", [](const at::Tensor& t) {
+    return at::functionalization::impl::to_functional_tensor(t);
+  });
+  py_module.def("_from_functional_tensor", [](const at::Tensor& t) {
+    return at::functionalization::impl::from_functional_tensor(t);
+  });
+  py_module.def("_freeze_functional_tensor", [](const at::Tensor& t) {
+    at::functionalization::impl::freeze_functional_tensor(t);
+  });
+  py_module.def(
+      "_enable_functionalization",
+      [](bool reapply_views = false) {
+        if (c10::impl::tls_is_dispatch_key_included(
+                at::DispatchKey::Functionalize)) {
+          TORCH_INTERNAL_ASSERT(
+              false,
+              "multiple layers of mode-style functionalization nesting is not"
+              " currently supported, outside of the functionalize() transform");
+        }
+        c10::impl::tls_set_dispatch_key_included(
+            at::DispatchKey::Functionalize, true);
+        if (reapply_views) {
+          at::functionalization::impl::setFunctionalizationReapplyViewsTLS(
+              true);
+        }
+      },
+      py::arg("reapply_views") = false);
+  py_module.def("_disable_functionalization", []() {
+    c10::impl::tls_set_dispatch_key_included(
+        at::DispatchKey::Functionalize, false);
+    at::functionalization::impl::setFunctionalizationReapplyViewsTLS(false);
+  });
+  py_module.def(
+      "_mirror_autograd_meta_to",
+      [](const at::Tensor& src_, const at::Tensor& dst_) {
+        // Here, we unsafely set the grad function on the wrapper to be the same
+        // as the inner. We expect this grad_fn to NEVER be used. It's needed so
+        // that .is_leaf metadata is accurate on the wrapper
+        auto inner_autograd_meta = impl::get_autograd_meta(src_);
+        if (inner_autograd_meta) {
+          dst_.set_requires_grad(src_.requires_grad());
+          if (dst_.requires_grad()) {
+            auto new_grad_fn = std::shared_ptr<torch::autograd::Error>(
+                new torch::autograd::Error(
                    "Cannot backprop through mirrored meta, file a bug in PyTorch"),
+                torch::autograd::deleteNode);
+            torch::autograd::set_history(dst_, new_grad_fn);
+          }
+        }
+      });
 }
 
 } // namespace autograd
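
A minimal usage sketch (not part of the diff above) of the relocated bindings
and the new _functionalize_set_storage_changed hook. It assumes a build that
exposes these private torch.* bindings and mirrors the enable/wrap/disable
pattern used by the functionalize() transform; the tensor values and the
try/finally scaffolding are illustrative only:

import torch

t = torch.ones(2)  # create the plain input before enabling functionalization
torch._enable_functionalization(reapply_views=True)
try:
    ft = torch._to_functional_tensor(t)
    assert torch._is_functional_tensor(ft)

    ft.add_(1)  # a data mutation, recorded on the FunctionalTensorWrapper
    assert torch._functionalize_has_data_mutation(ft)

    # New in this patch: manually flag the wrapper's storage as mutated,
    # then observe it through the existing query binding.
    torch._functionalize_set_storage_changed(ft)
    assert torch._functionalize_was_storage_changed(ft)

    torch._functionalize_sync(ft)  # propagate any pending view updates
    plain = torch._from_functional_tensor(ft)  # unwrap back to a plain tensor
finally:
    torch._disable_functionalization()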