From b3e4f5029b7af07050f4e71d5c96e207283beeaf Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 12 Jan 2023 21:00:28 +0200 Subject: [PATCH] Add check-sparse-tensor-invariants flag to Context - 2nd try. (#92094) This PR is a copy of https://github.com/pytorch/pytorch/pull/90849, whose merge was reverted. The PR adds a "check sparse tensor invariants" flag to Context that, when enabled, triggers sparse tensor data invariant checks in the unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following UI changes: the `torch.sparse.check_sparse_tensor_invariants` class provides several ways to enable or disable the invariant checking, and the `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable or disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden. The PR fixes https://github.com/pytorch/pytorch/issues/90833 Pull Request resolved: https://github.com/pytorch/pytorch/pull/92094 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/Context.cpp | 8 ++ aten/src/ATen/Context.h | 3 + .../ATen/native/sparse/SparseCsrTensor.cpp | 10 +- aten/src/ATen/native/sparse/SparseTensor.cpp | 20 +-- docs/source/sparse.rst | 30 ++++- docs/source/torch.rst | 4 + test/test_sparse.py | 107 +++++++++++++++ test/test_sparse_csr.py | 34 +++-- tools/pyi/gen_pyi.py | 9 +- torch/_C/__init__.pyi.in | 2 + torch/__init__.py | 3 +- torch/_torch_docs.py | 30 +++-- torch/_utils.py | 11 +- torch/csrc/Module.cpp | 29 +++++ .../python_torch_functions_manual.cpp | 38 +++--- torch/csrc/utils/tensor_new.cpp | 122 ++++++++++++++---- torch/sparse/__init__.py | 106 +++++++++++++++ torch/testing/_internal/common_utils.py | 23 ++++ 18 files changed, 493 insertions(+), 96 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index dd33ded7615bf6..b6cda72cf1e929 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -367,6 +367,14 @@ bool Context::isXNNPACKAvailable() { #endif } +void Context::setCheckSparseTensorInvariants(bool e) { + enable_sparse_tensor_invariant_checks = e; +} + +bool Context::checkSparseTensorInvariants() const { + return enable_sparse_tensor_invariant_checks; +} + bool Context::releaseWeightsWhenPrepacking() const { return release_original_weights; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 9ab289b779e03d..8816fb0872a842 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -250,6 +250,8 @@ class TORCH_API Context { void setQEngine(at::QEngine e); static const std::vector<at::QEngine>& supportedQEngines(); static bool isXNNPACKAvailable(); + void setCheckSparseTensorInvariants(bool e); + bool checkSparseTensorInvariants() const; // This method is used to release the original weight after pre-packing. // It should be called once before loading/running the model. // NB: By default it is set to true for mobile builds.
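For orientation: the Context flag above is the backing store for the user-facing controls added later in this patch, namely the torch.sparse.check_sparse_tensor_invariants helper and the new check_invariants constructor argument. A minimal usage sketch, assuming a build that includes this patch (inputs and error text are taken from the docs and tests below):

    import torch

    # The global flag defaults to off (enable_sparse_tensor_invariant_checks = false).
    assert not torch.sparse.check_sparse_tensor_invariants.is_enabled()

    with torch.sparse.check_sparse_tensor_invariants():
        # crow_indices[-1] == 3 claims nnz == 3, but only two values are given,
        # so the checked constructor raises instead of building a corrupt tensor.
        try:
            torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2])
        except RuntimeError as e:
            print(e)  # `crow_indices[..., -1] == nnz` is not satisfied.

    # Per-call override, independent of the global state:
    t = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [1, 2, 3, 4], check_invariants=True)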
@@ -305,6 +307,7 @@ class TORCH_API Context { #endif bool display_vmap_fallback_warnings_ = false; c10::optional<at::QEngine> quantized_engine = c10::nullopt; + bool enable_sparse_tensor_invariant_checks = false; Allocator* prev_allocator_ptr_{nullptr}; }; diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 9e3fa5f035b6d1..3d2526c4120436 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -356,6 +356,9 @@ Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices, } Layout layout_ = layout.value(); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{}); + if (at::globalContext().checkSparseTensorInvariants()) { + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_); + } TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); SparseCsrTensor self = new_compressed_tensor(options); get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size); @@ -373,6 +376,9 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice c10::optional<bool> pin_memory) { Layout layout_ = layout.value_or(required_layout); TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ", required_layout, " but got ", layout_); + if (at::globalContext().checkSparseTensorInvariants()) { + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_); + } TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); SparseCsrTensor self = new_compressed_tensor(options); get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size); @@ -474,8 +480,6 @@ Tensor sparse_compressed_tensor( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); - _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_); - return at::native::_sparse_compressed_tensor_unsafe( compressed_indices, plain_indices, @@ -507,8 +511,6 @@ Tensor sparse_compressed_tensor( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); - _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_); - return at::native::_sparse_compressed_tensor_unsafe( compressed_indices, plain_indices, diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 37f6380757d4d8..d24068c0a05cfe 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -398,8 +398,6 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe !options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout()); - - at::native::_validate_sparse_coo_tensor_args(indices, values, size); return at::native::_sparse_coo_tensor_unsafe( indices, values, @@ -415,20 +413,10 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, a c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) { - return
at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); - - Tensor values = expand_values_if_needed(values_); - - auto sparse_dim = indices.size(0); - auto dense_dim = values.dim() - 1; - - return at::_sparse_coo_tensor_with_dims_and_tensors( - sparse_dim, - dense_dim, - size, - indices, - values, - values.options().layout(kSparse)); + if (at::globalContext().checkSparseTensorInvariants()) { + at::native::_validate_sparse_coo_tensor_args(indices, values_, size); + } + return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); } // NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor() diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 77e8dabec2744c..377368e09c7831 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -321,6 +321,15 @@ invariants: Dense dimensions always follow sparse dimensions, that is, mixing of dense and sparse dimensions is not supported. +.. note:: + + To be sure that a constructed sparse tensor has consistent indices, + values, and size, the invariant checks can be enabled per tensor + creation via the ``check_invariants=True`` keyword argument, or + globally using the :class:`torch.sparse.check_sparse_tensor_invariants` + context manager. By default, the sparse tensor invariant + checks are disabled. + .. _sparse-uncoalesced-coo-docs: Uncoalesced sparse COO tensors @@ -530,6 +539,13 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_o where ``plain_dim_size`` is the number of plain dimensions (orthogonal to compressed dimensions, e.g. columns or rows). + To be sure that a constructed sparse tensor has consistent indices, + values, and size, the invariant checks can be enabled per tensor + creation via the ``check_invariants=True`` keyword argument, or + globally using the :class:`torch.sparse.check_sparse_tensor_invariants` + context manager. By default, the sparse tensor invariant + checks are disabled. + .. note:: The generalization of sparse compressed layouts to N-dimensional @@ -646,9 +662,9 @@ argument is optional and will be deduced from the ``crow_indices`` and >>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64) >>> csr tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 1]), - values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, - dtype=torch.float64) + col_indices=tensor([0, 1, 0, 1]), + values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4, + dtype=torch.float64) >>> csr.to_dense() tensor([[1., 2.], [3., 4.]], dtype=torch.float64) @@ -1160,6 +1176,14 @@ The following :mod:`torch` functions support sparse tensors: :func:`~torch.zeros` :func:`~torch.zeros_like` +To manage checking sparse tensor invariants, see: + +.. 
autosummary:: + :toctree: generated + :nosignatures: + + sparse.check_sparse_tensor_invariants + Unary functions --------------- diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 111ee21f6d83d8..1376058117955b 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -48,6 +48,10 @@ Creation Ops tensor sparse_coo_tensor + sparse_csr_tensor + sparse_csc_tensor + sparse_bsr_tensor + sparse_bsc_tensor asarray as_tensor as_strided diff --git a/test/test_sparse.py b/test/test_sparse.py index 5304ab7eaafc6d..d18e3bce8da821 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -4072,6 +4072,113 @@ def test_basic(self): class TestSparseAny(TestCase): + @onlyCPU + @all_sparse_layouts('layout', include_strided=False) + @torch.sparse.check_sparse_tensor_invariants(enable=False) + def test_check_sparse_tensor_invariants(self, layout): + + if layout is torch.sparse_coo: + + def create_invalid_tensor(check_invariants=None): + shape = (2, 2) + invalid_indices = torch.tensor([[0], [3]]) # column index is out of range + values = torch.tensor([1]) + if check_invariants is None: + return torch.sparse_coo_tensor(invalid_indices, values, shape) + else: + return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants) + + expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3' + + elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}: + + def create_invalid_tensor(check_invariants=None): + shape = (2, 2) + compressed_indices = torch.tensor([0, 0, 1]) + invalid_plain_indices = torch.tensor([3]) # index is out of range + if layout in {torch.sparse_bsr, torch.sparse_bsc}: + values = torch.tensor([[[1]]]) + else: + values = torch.tensor([1]) + if check_invariants is None: + return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout) + else: + return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout, + check_invariants=check_invariants) + + if layout in {torch.sparse_csr, torch.sparse_bsr}: + expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.' + else: + expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.' 
+ + else: + raise NotImplementedError(layout) + + # First, consider the case where invariant checks are disabled + # "globally" (read: within the context of this test method + # caller) as defined by the check_sparse_tensor_invariants(False) + # decorator: + self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Enable the invariant checks in a local context: + with torch.sparse.check_sparse_tensor_invariants(): + self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Leaving the local context must restore the "global" state of + # the invariant check feature: + self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Since invariant checks are disabled by default, we can + # create an invalid sparse tensor without raising an + # exception: + r = create_invalid_tensor() + self.assertEqual(r.layout, layout) + + # Or, when disabling the invariants check explicitly: + r = create_invalid_tensor(check_invariants=False) + self.assertEqual(r.layout, layout) + + # Enabling the invariant check via the constructor's optional argument + # will raise an exception when sparse tensor invariants are + # violated: + with self.assertRaisesRegex(RuntimeError, expected_exception_message): + create_invalid_tensor(check_invariants=True) + + # Check that the global invariant check flag has been restored + # after raising the exception above: + self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Next, consider the case where invariant checks are enabled + # within a local context: + with torch.sparse.check_sparse_tensor_invariants(): + self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Since invariant checks are now enabled by default, an + # attempt to create an invalid sparse tensor will lead to + # an exception: + with self.assertRaisesRegex(RuntimeError, expected_exception_message): + create_invalid_tensor() + + # Similarly, when enabling the invariant checks + # explicitly, invalid sparse tensor construction will lead + # to an exception: + with self.assertRaisesRegex(RuntimeError, expected_exception_message): + create_invalid_tensor(check_invariants=True) + + # However, the invariants check can be disabled via the + # constructor's optional argument so that the invalid + # tensor is successfully constructed: + r = create_invalid_tensor(check_invariants=False) + self.assertEqual(r.layout, layout) + + # Check that the invariant check flag has been restored + # when leaving the constructor: + self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + + # Double-check restoring the global state when leaving the + # local context: + self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) + def test_generate_simple_inputs(self): layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc] diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 5ba10452d5d246..30606d15b8597f 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1363,9 +1363,17 @@ def test_csr_matvec(self, device, dtype): @onlyCUDA @unittest.skipIf(not (CUDA11OrLater or TEST_WITH_ROCM), "Only CUDA 11+ is supported") + # hmm, the test passes ok on CUDA when ROCm is not available: @skipCUDAIfRocmVersionLessThan((5, 2)) @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_baddbmm(self, device, dtype): + + # TODO: disable the invariant checks within torch.baddbmm that + # constructs
unconventional csr tensors leading to + # RuntimeError: tensor dimensionality must be sum of batch, + # base, and dense dimensionalities (=0 + 2 + 0) but got 3 + # when invariant checking is enabled. When done, undecorate run_test. + @torch.sparse.check_sparse_tensor_invariants(enable=False) def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None): alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random() beta = complex(random.random(), random.random()) if dtype.is_complex else random.random() @@ -1388,8 +1396,8 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) # a_batched is a regular CSR tensor but with a batch dimension in the shape - a_batched = torch._sparse_csr_tensor_unsafe( - a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k)) + a_batched = torch.sparse_csr_tensor( + a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False) b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous) c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous) @@ -1420,9 +1428,13 @@ def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=No nnz = random.randint(0, m * k) a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) - # a_batched is a regular CSR tensor but with a batch dimension in the shape - a_batched = torch._sparse_csr_tensor_unsafe( - a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k)) + # a_batched is a regular CSR tensor but with a batch + # dimension in the shape. It is unorthodox in PyTorch + # to represent a batch sparse tensor in this way, + # hence checking the tensor invariants is locally + # turned off. 
+ a_batched = torch.sparse_csr_tensor( + a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False) b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous) for op_b, op_out in itertools.product([True, False], repeat=2): @@ -1549,8 +1561,8 @@ def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None): a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) a_data = a_data.mT if noncontiguous else a_data - a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(), - a_data, (m * block_size, k * block_size)) + a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), + a_data, (m * block_size, k * block_size), check_invariants=False) b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous) c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous) for op_b, op_out in itertools.product([True, False], repeat=2): @@ -1585,8 +1597,8 @@ def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks - a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(), - a_data, (m * block_size, k * block_size)) + a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), + a_data, (m * block_size, k * block_size), check_invariants=False) b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous) c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous) self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device) @@ -1658,8 +1670,8 @@ def run_test(a, b, upper, transpose, unitriangular, op_out): a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype) a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks - a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(), - a_data, (m * block_size, m * block_size)) + a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), + a_data, (m * block_size, m * block_size), check_invariants=False) b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous) for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4): diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 447ac0a9b62e46..a51144589d3fdf 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -364,7 +364,8 @@ def gen_pyi( f"{n2}_indices: Union[Tensor, List]," " values: Union[Tensor, List], size: Optional[_size]=None," " *, dtype: Optional[_dtype]=None," - " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." + " device: Union[_device, str, None]=None, requires_grad:_bool=False," + " check_invariants:_bool=None) -> Tensor: ..." 
], f"_sparse_{n}_tensor_unsafe": [ f"def _sparse_{n}_tensor_unsafe({n1}_indices: Union[Tensor, List]," @@ -411,7 +412,8 @@ def gen_pyi( "sparse_coo_tensor": [ "def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List]," " size: Optional[_size]=None, *, dtype: Optional[_dtype]=None," - " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." + " device: Union[_device, str, None]=None, requires_grad:_bool=False," + " check_invariants:_bool=None) -> Tensor: ..." ], "_sparse_coo_tensor_unsafe": [ "def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int]," @@ -423,7 +425,8 @@ def gen_pyi( "plain_indices: Union[Tensor, List]," " values: Union[Tensor, List], size: Optional[_size]=None," " *, dtype: Optional[_dtype]=None, layout: Optional[_layout] = None," - " device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ..." + " device: Union[_device, str, None]=None, requires_grad:_bool=False," + " check_invariants:_bool=None) -> Tensor: ..." ], "_sparse_compressed_tensor_unsafe": [ "def _sparse_compressed_tensor_unsafe(comp_indices: Union[Tensor, List]," diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 577f4bb38e5677..fef4bae0e04291 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -878,6 +878,8 @@ def _get_qengine() -> _int: ... # THPModule_qEngine def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK +def _check_sparse_tensor_invariants() -> _bool: ... # THPModule_checkSparseTensorInvariants +def _set_check_sparse_tensor_invariants(arg: _bool) -> None: ... # THPModule_setCheckSparseTensorInvariants def _set_default_mobile_cpu_allocator() -> None: ... # THPModule_setDefaultMobileCPUAllocator def _unset_default_mobile_cpu_allocator() -> None: ... # THPModule_unsetDefaultMobileCPUAllocator def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction diff --git a/torch/__init__.py b/torch/__init__.py index 1783b6a0d4f7cd..c76764092525f6 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -50,8 +50,7 @@ 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', 'set_float32_matmul_precision', 'get_float32_matmul_precision', 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', - 'sym_int', 'sym_float', 'compile', 'vmap' -] + 'sym_int', 'sym_float', 'compile', 'vmap'] ################################################################################ # Load the extension module diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index a0f2e78e9df544..9caa080551efa5 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -106,6 +106,9 @@ def merge_dicts(*dicts): the pinned memory. Works only for CPU tensors. Default: ``False``. memory_format (:class:`torch.memory_format`, optional): the desired memory format of returned Tensor. Default: ``torch.contiguous_format``. + check_invariants (bool, optional): If sparse tensor invariants are checked. + Default: as returned by :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled`, + initially False. 
""" ), { @@ -10161,7 +10164,7 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_compressed_tensor, r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """ - r"""*, dtype=None, layout=None, device=None, requires_grad=False) -> Tensor + r"""*, dtype=None, layout=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in Compressed Sparse format - CSR, CSC, BSR, or BSC - ` with specified values at @@ -10213,6 +10216,7 @@ def merge_dicts(*dicts): the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. {requires_grad} + {check_invariants} Example:: >>> compressed_indices = [0, 2, 4] @@ -10232,8 +10236,8 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_csr_tensor, - r""" -sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor + r"""sparse_csr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) ` with specified values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations @@ -10273,6 +10277,7 @@ def merge_dicts(*dicts): the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. {requires_grad} + {check_invariants} Example:: >>> crow_indices = [0, 2, 4] @@ -10292,8 +10297,8 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_csc_tensor, - r""" -sparse_csc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor + r"""sparse_csc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in CSC (Compressed Sparse Column) ` with specified values at the given @@ -10335,6 +10340,7 @@ def merge_dicts(*dicts): the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. {requires_grad} + {check_invariants} Example:: >>> ccol_indices = [0, 2, 4] @@ -10354,8 +10360,8 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_bsr_tensor, - r""" -sparse_bsr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor + r"""sparse_bsr_tensor(crow_indices, col_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in BSR (Block Compressed Sparse Row)) ` with specified 2-dimensional blocks at the given @@ -10399,6 +10405,7 @@ def merge_dicts(*dicts): the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. {requires_grad} + {check_invariants} Example:: >>> crow_indices = [0, 1, 2] @@ -10421,8 +10428,8 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_bsc_tensor, - r""" -sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor + r"""sparse_bsc_tensor(ccol_indices, row_indices, values, size=None, """ + r"""*, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in BSC (Block Compressed Sparse Column)) ` with specified 2-dimensional blocks at the @@ -10465,6 +10472,7 @@ def merge_dicts(*dicts): the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. 
{requires_grad} + {check_invariants} Example:: >>> ccol_indices = [0, 1, 2] @@ -10488,7 +10496,7 @@ def merge_dicts(*dicts): add_docstr( torch.sparse_coo_tensor, r""" -sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor +sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False, check_invariants=None) -> Tensor Constructs a :ref:`sparse tensor in COO(rdinate) format ` with specified values at the given @@ -10520,7 +10528,7 @@ def merge_dicts(*dicts): (see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. {requires_grad} - + {check_invariants} Example:: diff --git a/torch/_utils.py b/torch/_utils.py index d00d27571c254f..e9a07e86a09df3 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -237,7 +237,7 @@ def _rebuild_sparse_tensor(layout, data): """ if layout == torch.sparse_coo: indices, values, size = data - result = torch._sparse_coo_tensor_unsafe(indices, values, size) + result = torch.sparse_coo_tensor(indices, values, size, check_invariants=False) _sparse_tensors_to_validate.append(result) return result @@ -248,8 +248,13 @@ def _rebuild_sparse_tensor(layout, data): torch.sparse_bsc, }: compressed_indices, plain_indices, values, size = data - result = torch._sparse_compressed_tensor_unsafe( - compressed_indices, plain_indices, values, size, layout=layout + result = torch.sparse_compressed_tensor( + compressed_indices, + plain_indices, + values, + size, + layout=layout, + check_invariants=False, ) _sparse_tensors_to_validate.append(result) return result diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index f5ee578fd2bd34..1f4d9ac3016187 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -831,6 +831,27 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) { Py_RETURN_FALSE; } +PyObject* THPModule_setCheckSparseTensorInvariants( + PyObject* _unused, + PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_check_sparse_tensor_invariants expects a bool, " + "but got %s", + THPUtils_typename(arg)); + at::globalContext().setCheckSparseTensorInvariants(arg == Py_True); + Py_RETURN_NONE; +} + +PyObject* THPModule_checkSparseTensorInvariants( + PyObject* _unused, + PyObject* noargs) { + if (at::globalContext().checkSparseTensorInvariants()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; +} + PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS bool isTHPFunction = THPFunction_Check(arg); @@ -1122,6 +1143,14 @@ static PyMethodDef TorchMethods[] = { {"_set_qengine", THPModule_setQEngine, METH_O, nullptr}, {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr}, {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, + {"_set_check_sparse_tensor_invariants", + THPModule_setCheckSparseTensorInvariants, + METH_O, + nullptr}, + {"_check_sparse_tensor_invariants", + THPModule_checkSparseTensorInvariants, + METH_NOARGS, + nullptr}, {"_will_engine_execute_node", THPModule_willEngineExecuteNode, METH_O, diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 6aaaaf0eff6e9b..f444ca869fbd49 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -197,29 +197,29 @@ static PyObject* THPVariable_nonzero( THPVARIABLE_SPARSE_COMPRESSED_CTOR( 
sparse_compressed_tensor, - 9, - ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) + 10, + ({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)", + "sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"})) THPVARIABLE_SPARSE_COMPRESSED_CTOR( sparse_csr_tensor, - 9, - ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) + 10, + ({"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)", + "sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"})) THPVARIABLE_SPARSE_COMPRESSED_CTOR( sparse_csc_tensor, - 9, - ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) + 10, + ({"sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)", + "sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"})) THPVARIABLE_SPARSE_COMPRESSED_CTOR( sparse_bsr_tensor, - 9, - ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? 
device=None, bool pin_memory=False, bool requires_grad=False)"})) + 10, + ({"sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)", + "sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"})) THPVARIABLE_SPARSE_COMPRESSED_CTOR( sparse_bsc_tensor, - 9, - ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", - "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)"})) + 10, + ({"sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)", + "sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, bool check_invariants=None)"})) THPVARIABLE_SPARSE_COMPRESSED_CTOR( _sparse_compressed_tensor_unsafe, @@ -248,12 +248,12 @@ static PyObject* THPVariable_sparse_coo_tensor( PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ - "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", - "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", - "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", + "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)", + "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False, bool check_invariants=None)", + "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? 
device=None, bool requires_grad=False, bool check_invariants=None)", }); - ParsedArgs<6> parsed_args; + ParsedArgs<7> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.has_torch_function()) { return handle_torch_function( diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 37e121bd039299..434ecfa9697c8c 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -784,6 +784,19 @@ Tensor indexing_tensor_from_data( } } +class CheckSparseTensorInvariantsContext { + public: + CheckSparseTensorInvariantsContext() { + state = at::globalContext().checkSparseTensorInvariants(); + } + ~CheckSparseTensorInvariantsContext() { + at::globalContext().setCheckSparseTensorInvariants(state); + } + + private: + bool state; +}; + Tensor sparse_compressed_tensor_ctor_worker( std::string name, c10::DispatchKey dispatch_key, @@ -802,6 +815,7 @@ Tensor sparse_compressed_tensor_ctor_worker( ARG_DEVICE, ARG_PIN_MEMORY, ARG_REQUIRES_GRAD, + ARG_CHECK_INVARIANTS, ARGS_COUNT }; enum { @@ -811,6 +825,7 @@ Tensor sparse_compressed_tensor_ctor_worker( ARG_DEVICE1, ARG_PIN_MEMORY1, ARG_REQUIRES_GRAD1, + ARG_CHECK_INVARIANTS1, ARGS_COUNT1 }; @@ -840,6 +855,10 @@ Tensor sparse_compressed_tensor_ctor_worker( at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr ? reinterpret_cast(plain_indices_dtype_attr.get())->scalar_type : kInt; + CheckSparseTensorInvariantsContext + restores_check_sparse_tensor_invariants_global_state{}; + bool default_check_invariants = + at::globalContext().checkSparseTensorInvariants(); if (r.idx == 0) { bool type_inference = r.isNone(ARG_TYPE); @@ -848,6 +867,10 @@ Tensor sparse_compressed_tensor_ctor_worker( const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); + // the global state of invariants check flag will be restored via + // CheckSparseTensorInvariantsContext destructor + at::globalContext().setCheckSparseTensorInvariants( + r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants)); Tensor values = internal_new_from_data( inferred_options, @@ -900,6 +923,10 @@ Tensor sparse_compressed_tensor_ctor_worker( const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE1, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1)); + // the global state of invariants check flag will be restored via + // CheckSparseTensorInvariantsContext destructor + at::globalContext().setCheckSparseTensorInvariants( + r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants)); Tensor values = internal_new_from_data( inferred_options, @@ -1170,17 +1197,54 @@ Tensor sparse_coo_tensor_ctor( PythonArgs& r) { TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); + enum { + ARG_INDICES = 0, + ARG_VALUES, + ARG_TYPE, + ARG_DEVICE, + ARG_REQUIRES_GRAD, + ARG_CHECK_INVARIANTS, + ARGS_COUNT + }; + enum { + ARG_INDICES1 = 0, + ARG_VALUES1, + ARG_SIZE1, + ARG_TYPE1, + ARG_DEVICE1, + ARG_REQUIRES_GRAD1, + ARG_CHECK_INVARIANTS1, + ARGS_COUNT1 + }; + enum { + ARG_SIZE2 = 0, + ARG_TYPE2, + ARG_DEVICE2, + ARG_REQUIRES_GRAD2, + ARG_CHECK_INVARIANTS2, + ARGS_COUNT2 + }; + CheckSparseTensorInvariantsContext + restores_check_sparse_tensor_invariants_global_state{}; + bool default_check_invariants = + at::globalContext().checkSparseTensorInvariants(); + if (r.idx == 0) { - bool type_inference = r.isNone(2); - 
const auto inferred_options = typeIdWithDefault(r, 3, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type); - at::OptionalDeviceGuard device_guard(r.deviceOptional(3)); + bool type_inference = r.isNone(ARG_TYPE); + const auto inferred_options = + typeIdWithDefault(r, ARG_DEVICE, dispatch_key); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE, scalar_type); + at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); + at::globalContext().setCheckSparseTensorInvariants( + r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants)); + // if no dtype provided, infer type based on value type. Tensor values = internal_new_from_data( inferred_options, inferred_scalar_type, - r.deviceOptional(3), - r.pyobject(1), + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/type_inference); @@ -1188,24 +1252,29 @@ Tensor sparse_coo_tensor_ctor( Tensor indices = internal_new_from_data( values.options(), kLong, - r.deviceOptional(3), - r.pyobject(0), + r.deviceOptional(ARG_DEVICE), + r.pyobject(ARG_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/false); return at::sparse_coo_tensor( indices, values, values.options().layout(at::kSparse)) - .set_requires_grad(r.toBool(4)); + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); } else if (r.idx == 1) { - bool type_inference = r.isNone(3); - const auto inferred_options = typeIdWithDefault(r, 4, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type); - at::OptionalDeviceGuard device_guard(r.deviceOptional(4)); + bool type_inference = r.isNone(ARG_TYPE1); + const auto inferred_options = + typeIdWithDefault(r, ARG_DEVICE1, dispatch_key); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE1, scalar_type); + at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1)); + at::globalContext().setCheckSparseTensorInvariants( + r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants)); + Tensor values = internal_new_from_data( inferred_options, inferred_scalar_type, - r.deviceOptional(4), - r.pyobject(1), + r.deviceOptional(ARG_DEVICE1), + r.pyobject(ARG_VALUES1), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/type_inference); @@ -1213,25 +1282,30 @@ Tensor sparse_coo_tensor_ctor( Tensor indices = internal_new_from_data( values.options(), kLong, - r.deviceOptional(4), - r.pyobject(0), + r.deviceOptional(ARG_DEVICE1), + r.pyobject(ARG_INDICES1), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/false); return at::sparse_coo_tensor( indices, values, - r.intlist(2), + r.intlist(ARG_SIZE1), values.options().layout(at::kSparse)) - .set_requires_grad(r.toBool(5)); + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1)); } else if (r.idx == 2) { - const auto inferred_options = typeIdWithDefault(r, 2, dispatch_key); - const auto inferred_scalar_type = r.scalartypeWithDefault(1, scalar_type); - at::OptionalDeviceGuard device_guard(r.deviceOptional(2)); + const auto inferred_options = + typeIdWithDefault(r, ARG_DEVICE2, dispatch_key); + const auto inferred_scalar_type = + r.scalartypeWithDefault(ARG_TYPE2, scalar_type); + at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE2)); + at::globalContext().setCheckSparseTensorInvariants( + r.toBoolWithDefault(ARG_CHECK_INVARIANTS2, default_check_invariants)); + return at::sparse_coo_tensor( - r.intlist(0), + r.intlist(ARG_SIZE2), 
inferred_options.dtype(inferred_scalar_type).layout(at::kSparse)) - .set_requires_grad(r.toBool(3)); + .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2)); } throw std::runtime_error("sparse_coo_tensor(): invalid arguments"); } diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 2a921a4a26ba33..3ceaf56fc203f5 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -18,6 +18,7 @@ __all__ = [ 'addmm', + 'check_sparse_tensor_invariants', 'mm', 'sum', 'softmax', @@ -356,3 +357,108 @@ def sum(input: Tensor, dim: DimOrDims = None, [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) """) + + +class check_sparse_tensor_invariants(object): + """A tool to control checking sparse tensor invariants. + +The following options exist to manage sparse tensor invariant +checking in sparse tensor construction: + +1. Using a context manager: + + .. code:: python + + with torch.sparse.check_sparse_tensor_invariants(): + run_my_model() + +2. Using a procedural approach: + + .. code:: python + + prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled() + torch.sparse.check_sparse_tensor_invariants.enable() + + run_my_model() + + if not prev_checks_enabled: + torch.sparse.check_sparse_tensor_invariants.disable() + +3. Using function decoration: + + .. code:: python + + @torch.sparse.check_sparse_tensor_invariants() + def run_my_model(): + ... + + run_my_model() + +4. Using the ``check_invariants`` keyword argument in a sparse tensor constructor call. + For example: + + >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied. + """ + + @staticmethod + def is_enabled(): + r"""Returns True if sparse tensor invariant checking is enabled. + +.. note:: + + Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or + :func:`torch.sparse.check_sparse_tensor_invariants.disable` to + manage the state of the sparse tensor invariant checks. + """ + return torch._C._check_sparse_tensor_invariants() + + @staticmethod + def enable(): + r"""Enable sparse tensor invariant checking in sparse tensor constructors. + +.. note:: + + By default, the sparse tensor invariant checks are disabled. Use + :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to + retrieve the current state of sparse tensor invariant checking. + +.. note:: + + The sparse tensor invariants check flag applies to all sparse + tensor constructors, both in Python and ATen. + + The flag can be locally overridden by the ``check_invariants`` + optional argument of the sparse tensor constructor functions. + """ + torch._C._set_check_sparse_tensor_invariants(True) + + @staticmethod + def disable(): + r"""Disable sparse tensor invariant checking in sparse tensor constructors. + +See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information. 
+ """ + torch._C._set_check_sparse_tensor_invariants(False) + + # context manager support + def __init__(self, enable=True): + self.state = enable + self.saved_state = self.is_enabled() + + def __enter__(self): + torch._C._set_check_sparse_tensor_invariants(self.state) + + def __exit__(self, type, value, traceback): + torch._C._set_check_sparse_tensor_invariants(self.saved_state) + + # decorator support + def __call__(self, mth): + + def test_mth(*args, **kwargs): + with type(self)(self.state): + return mth(*args, **kwargs) + + return test_mth diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 407de7c829159d..f6161990ce13c9 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2220,6 +2220,29 @@ def setUp(self): check_if_enable(self) set_rng_seed(SEED) + # Save global check sparse tensor invariants state that can be + # restored from tearDown: + self._check_invariants = torch.sparse.check_sparse_tensor_invariants.is_enabled() + + # Enable invariant checks for all sparse tensors constructions + # including the unsafe ones. If this is not desired for some + # test case, use check_invariants=False optional argument to + # sparse tensor constructors or + # @torch.sparse.check_sparse_tensor_invariants(False) + # decorator to disable the invariant checks. + torch.sparse.check_sparse_tensor_invariants.enable() + + def tearDown(self): + # There exists test cases that override TestCase.setUp + # definition, so we cannot assume that _check_invariants + # attribute is defined in general. + if hasattr(self, '_check_invariants'): + # Restore the global check sparse tensor invariants state + if self._check_invariants: + torch.sparse.check_sparse_tensor_invariants.enable() + else: + torch.sparse.check_sparse_tensor_invariants.disable() + @staticmethod def _make_crow_indices(n_rows, n_cols, nnz, *, device, dtype, random=True):