Skip to content

Commit

Permalink
Add check-sparse-tensor-invariants flag to Context - 2nd try. (pytorc…
Browse files Browse the repository at this point in the history
…h#92094)

This PR is a copy of pytorch#90849 that merge was reverted.

The PR adds "check sparse tensor invariants" flag to Context that when enabled will trigger sparse tensor data invariants checks in unsafe methods of constructing sparse COO/CSR/CSC/BSR/BSC tensors. The feature includes the following changes to UI:

`torch.sparse.check_sparse_tensor_invariants` class provides different ways to enable/disable the invariant checking.

`torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor` functions have a new optional argument `check_invariants` to enable/disable the invariant checks explicitly. When the `check_invariants` argument is specified, the global state of the feature is temporarily overridden.

The PR fixes pytorch#90833

Pull Request resolved: pytorch#92094
Approved by: https://github.com/cpuhrsch
  • Loading branch information
pearu authored and pytorchmergebot committed Jan 13, 2023
1 parent a111dd9 commit b3e4f50
Show file tree
Hide file tree
Showing 18 changed files with 493 additions and 96 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ bool Context::isXNNPACKAvailable() {
#endif
}

void Context::setCheckSparseTensorInvariants(bool e) {
enable_sparse_tensor_invariant_checks = e;
}

bool Context::checkSparseTensorInvariants() const {
return enable_sparse_tensor_invariant_checks;
}

bool Context::releaseWeightsWhenPrepacking() const {
return release_original_weights;
}
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ class TORCH_API Context {
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
static bool isXNNPACKAvailable();
void setCheckSparseTensorInvariants(bool e);
bool checkSparseTensorInvariants() const;
// This method is used to release the original weight after pre-packing.
// It should be called once before loading/running the model.
// NB: By default it is set to true for mobile builds.
Expand Down Expand Up @@ -305,6 +307,7 @@ class TORCH_API Context {
#endif
bool display_vmap_fallback_warnings_ = false;
c10::optional<at::QEngine> quantized_engine = c10::nullopt;
bool enable_sparse_tensor_invariant_checks = false;

Allocator* prev_allocator_ptr_{nullptr};
};
Expand Down
10 changes: 6 additions & 4 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
}
Layout layout_ = layout.value();
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
Expand All @@ -373,6 +376,9 @@ Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indice
c10::optional<bool> pin_memory) {
Layout layout_ = layout.value_or(required_layout);
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
if (at::globalContext().checkSparseTensorInvariants()) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);
}
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
Expand Down Expand Up @@ -474,8 +480,6 @@ Tensor sparse_compressed_tensor(
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
Expand Down Expand Up @@ -507,8 +511,6 @@ Tensor sparse_compressed_tensor(
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);

_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout_);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
Expand Down
20 changes: 4 additions & 16 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,6 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
!options.has_layout() || options.layout() == kSparse,
"expected sparse layout, but got layout ",
options.layout());

at::native::_validate_sparse_coo_tensor_args(indices, values, size);
return at::native::_sparse_coo_tensor_unsafe(
indices,
values,
Expand All @@ -415,20 +413,10 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, a
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);

Tensor values = expand_values_if_needed(values_);

auto sparse_dim = indices.size(0);
auto dense_dim = values.dim() - 1;

return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim,
dense_dim,
size,
indices,
values,
values.options().layout(kSparse));
if (at::globalContext().checkSparseTensorInvariants()) {
at::native::_validate_sparse_coo_tensor_args(indices, values_, size);
}
return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}

// NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
Expand Down
30 changes: 27 additions & 3 deletions docs/source/sparse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,15 @@ invariants:
Dense dimensions always follow sparse dimensions, that is, mixing
of dense and sparse dimensions is not supported.

.. note::

To be sure that a constructed sparse tensor has consistent indices,
values, and size, the invariant checks can be enabled per tensor
creation via ``check_invariants=True`` keyword argument, or
globally using :class:`torch.sparse.check_sparse_tensor_invariants`
context manager instance. By default, the sparse tensor invariants
checks are disabled.

.. _sparse-uncoalesced-coo-docs:

Uncoalesced sparse COO tensors
Expand Down Expand Up @@ -530,6 +539,13 @@ __ https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_o
where ``plain_dim_size`` is the number of plain dimensions
(orthogonal to compressed dimensions, e.g. columns or rows).

To be sure that a constructed sparse tensor has consistent indices,
values, and size, the invariant checks can be enabled per tensor
creation via ``check_invariants=True`` keyword argument, or
globally using :class:`torch.sparse.check_sparse_tensor_invariants`
context manager instance. By default, the sparse tensor invariants
checks are disabled.

.. note::

The generalization of sparse compressed layouts to N-dimensional
Expand Down Expand Up @@ -646,9 +662,9 @@ argument is optional and will be deduced from the ``crow_indices`` and
>>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float64)
>>> csr
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64)
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64)
>>> csr.to_dense()
tensor([[1., 2.],
[3., 4.]], dtype=torch.float64)
Expand Down Expand Up @@ -1160,6 +1176,14 @@ The following :mod:`torch` functions support sparse tensors:
:func:`~torch.zeros`
:func:`~torch.zeros_like`

To manage checking sparse tensor invariants, see:

.. autosummary::
:toctree: generated
:nosignatures:

sparse.check_sparse_tensor_invariants

Unary functions
---------------

Expand Down
4 changes: 4 additions & 0 deletions docs/source/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ Creation Ops

tensor
sparse_coo_tensor
sparse_csr_tensor
sparse_csc_tensor
sparse_bsr_tensor
sparse_bsc_tensor
asarray
as_tensor
as_strided
Expand Down
107 changes: 107 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,6 +4072,113 @@ def test_basic(self):

class TestSparseAny(TestCase):

@onlyCPU
@all_sparse_layouts('layout', include_strided=False)
@torch.sparse.check_sparse_tensor_invariants(enable=False)
def test_check_sparse_tensor_invariants(self, layout):

if layout is torch.sparse_coo:

def create_invalid_tensor(check_invariants=None):
shape = (2, 2)
invalid_indices = torch.tensor([[0], [3]]) # column index is out of range
values = torch.tensor([1])
if check_invariants is None:
return torch.sparse_coo_tensor(invalid_indices, values, shape)
else:
return torch.sparse_coo_tensor(invalid_indices, values, shape, check_invariants=check_invariants)

expected_exception_message = 'size is inconsistent with indices: for dim 1, size is 2 but found index 3'

elif layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:

def create_invalid_tensor(check_invariants=None):
shape = (2, 2)
compressed_indices = torch.tensor([0, 0, 1])
invalid_plain_indices = torch.tensor([3]) # index is out of range
if layout in {torch.sparse_bsr, torch.sparse_bsc}:
values = torch.tensor([[[1]]])
else:
values = torch.tensor([1])
if check_invariants is None:
return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout)
else:
return torch.sparse_compressed_tensor(compressed_indices, invalid_plain_indices, values, shape, layout=layout,
check_invariants=check_invariants)

if layout in {torch.sparse_csr, torch.sparse_bsr}:
expected_exception_message = r'`0 <= col_indices < ncols` is not satisfied.'
else:
expected_exception_message = r'`0 <= row_indices < nrows` is not satisfied.'

else:
raise NotImplementedError(layout)

# First, consider the case where invariant checks are disabled
# "globally" (read: within the context of this test method
# caller) as defined by check_sparse_tensor_invariants(False)
# decorator:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Enable the invariant checks in a local context:
with torch.sparse.check_sparse_tensor_invariants():
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Leaving the local context must restore the "global" state of
# the invariant check feature:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Since invariant checks are disabled by default, we can
# create an invalid sparse tensor without raising an
# exception:
r = create_invalid_tensor()
self.assertEqual(r.layout, layout)

# Or, when disabling the invariants check explicitly:
r = create_invalid_tensor(check_invariants=False)
self.assertEqual(r.layout, layout)

# Enabling invariant check via constructor's optional argument
# will raise an exception when sparse tensor invariants are
# violated:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor(check_invariants=True)

# Check that the global invariant check flag has been restored
# after raising the exception above:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Next, consider the case where invariant checks are enabled
# within a local context:
with torch.sparse.check_sparse_tensor_invariants():
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Since invariant checks are now enabled by default, an
# attempt to create an invalid sparse tensor will lead to
# an exception:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor()

# Similarly, when enabling the invariant checks
# explicitly, invalid sparse tensor construction will lead
# to an exception:
with self.assertRaisesRegex(RuntimeError, expected_exception_message):
create_invalid_tensor(check_invariants=True)

# However, invariants check can be disabled via
# constructor's optional argument so that the invalid
# tensor is succesfully constructed:
r = create_invalid_tensor(check_invariants=False)
self.assertEqual(r.layout, layout)

# Check that the invariant check flag has been restored
# when leaving the constructor:
self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled())

# Double-check restoring the global state when leaving the
# local context:
self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled())

def test_generate_simple_inputs(self):
layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]

Expand Down
34 changes: 23 additions & 11 deletions test/test_sparse_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,9 +1363,17 @@ def test_csr_matvec(self, device, dtype):

@onlyCUDA
@unittest.skipIf(not (CUDA11OrLater or TEST_WITH_ROCM), "Only CUDA 11+ is supported")
# hmm, the test passes ok on CUDA when Rocm is not available:
@skipCUDAIfRocmVersionLessThan((5, 2))
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_baddbmm(self, device, dtype):

# TODO: disable the invariant checks within torch.baddbmm that
# constructs unconventional csr tensors leading to
# RuntimeError: tensor dimensionality must be sum of batch,
# base, and dense dimensionalities (=0 + 2 + 0) but got 3
# when invariant checking is enabled. When done, undecorate run_test.
@torch.sparse.check_sparse_tensor_invariants(enable=False)
def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
Expand All @@ -1388,8 +1396,8 @@ def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

# a_batched is a regular CSR tensor but with a batch dimension in the shape
a_batched = torch._sparse_csr_tensor_unsafe(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
a_batched = torch.sparse_csr_tensor(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
Expand Down Expand Up @@ -1420,9 +1428,13 @@ def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=No
nnz = random.randint(0, m * k)
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

# a_batched is a regular CSR tensor but with a batch dimension in the shape
a_batched = torch._sparse_csr_tensor_unsafe(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k))
# a_batched is a regular CSR tensor but with a batch
# dimension in the shape. It is unorthodox in PyTorch
# to represent a batch sparse tensor in this way,
# hence checking the tensor invariants is locally
# turned off.
a_batched = torch.sparse_csr_tensor(
a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_b, op_out in itertools.product([True, False], repeat=2):
Expand Down Expand Up @@ -1549,8 +1561,8 @@ def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size), check_invariants=False)
b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_b, op_out in itertools.product([True, False], repeat=2):
Expand Down Expand Up @@ -1585,8 +1597,8 @@ def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous
a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, k * block_size), check_invariants=False)
b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)
Expand Down Expand Up @@ -1658,8 +1670,8 @@ def run_test(a, b, upper, transpose, unitriangular, op_out):
a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks
a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, m * block_size))
a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
a_data, (m * block_size, m * block_size), check_invariants=False)
b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)

for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4):
Expand Down
Loading

0 comments on commit b3e4f50

Please sign in to comment.