Add spdiags sparse matrix initialization (pytorch#78439)
Similar to [scipy.sparse.spdiags](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spdiags.html#scipy-sparse-spdiags)

Part of pytorch#70926

In other functions (i.e. [torch.diagonal](https://pytorch.org/docs/stable/generated/torch.diagonal.html#torch.diagonal)) diagonals of a tensor are referenced using the offset and the two dimensions that the diagonal is taken with respect to.

The reference implementation from scipy only considers matrix output, so we only support 2-D output at first. Even so, it may be useful to consider how the dimensions corresponding to each diagonal would be specified for higher-dimensional output.

The proposed torch signature implies that all offsets refer to the diagonals with respect to the only two dimensions of the output:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, int[] shape, Layout? layout=None) -> SparseTensor
```
Above it is required that `diagonals.ndimension() == 2`, `offsets.ndimension() == 1`, `offsets.shape[0] == diagonals.shape[0]`, and `len(shape) == 2`.
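
For concreteness, a minimal usage sketch of the proposed 2-D signature (the expected output in the comments is illustrative; the placement rule mirrors scipy, where row `i` of `diagonals` is indexed by output column, so a positive offset `k` skips the first `k` entries of that row):

```
import torch

diagonals = torch.tensor([[1, 2, 3],
                          [4, 5, 6]])
offsets = torch.tensor([0, -1])
out = torch.sparse.spdiags(diagonals, offsets, (3, 3))
print(out.to_dense())
# tensor([[1, 0, 0],
#         [4, 2, 0],
#         [0, 5, 3]])
```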

This would need to be altered for the case where `len(shape) > 2`. One option is:
```
torch.sparse.spdiags(Tensor[] diagonals, IntTensor[] offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```

Here `offsets` and `diagonals` become lists of tensors, and the `IntTensor dims` argument is introduced. This would require that `len(diagonals) == len(offsets) == dims.shape[0]`, `dims.ndimension() == 2`, and `dims.shape[1] == 2`. The same restrictions as the 2-D case above apply to the elements of `diagonals` and `offsets` pairwise, that is, `diagonals[i].ndimension() == 2`, `offsets[i].ndimension() == 1`, and `offsets[i].shape[0] == diagonals[i].shape[0]` for all `i`. This form of the signature would construct the sparse result by placing the values from `diagonals[i][j]` into the diagonal with offset `offsets[i][j]` taken with respect to dimensions `dims[i]`. The specialization back to the original signature for the 2-D case could be seen as allowing the single row of `dims` to default to `[0, 1]` when only one `diagonals`/`offsets` pair is provided and `len(shape) == 2`. This option allows the rows of an input element `diagonals[i]` to have different lengths, which may be appropriate since the maximum length of a diagonal along different dimension pairs will differ.
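
As a rough dense emulation of what this first option would compute (a hypothetical sketch only, not part of this PR; the function name and the choice to broadcast each row across the remaining output dimensions are assumptions):

```
import torch

def spdiags_nd_dense(diagonals, offsets, dims, shape, dtype=torch.float64):
    # Hypothetical dense reference for the first proposed N-D signature:
    # `diagonals` and `offsets` are lists of tensors; `dims` is a
    # (len(diagonals), 2) tensor giving the dimension pair for each group.
    out = torch.zeros(shape, dtype=dtype)
    for diags_i, offs_i, (d0, d1) in zip(diagonals, offsets, dims.tolist()):
        for row, off in zip(diags_i, offs_i.tolist()):
            view = out.diagonal(off, dim0=d0, dim1=d1)
            src = row[off:] if off > 0 else row  # scipy-style column alignment
            n = min(view.shape[-1], src.shape[0])
            view[..., :n] = src[:n].to(dtype)
    return out
```

For a 2-D `shape` with a single `diagonals`/`offsets` pair and `dims = torch.tensor([[0, 1]])` this reduces to the 2-D behaviour sketched above.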

Another option is to specify the dimensions the diagonal is taken with respect to for each offset. This signature would look like:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```
Here, `diagonals` is still 2-D with dimension 0 matching the length of the 1-D `offsets`, and the tensor input `dims` is also 2-D with dimension 0 matching the length of the 1-D `offsets` and its second dimension fixed at `2`. In this case the sparse result is constructed by placing the elements from `diagonals[i]` into the output diagonal `output.diagonal(offsets[i], dim0=dims[i][0], dim1=dims[i][1])` (with some additional consideration that makes it more complicated than simply assigning to that view). The specialization from this back to the 2-D form could be seen as assuming `dims = [[0, 1], [0, 1], ... len(offsets) times]` when `len(shape) == 2`.
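
Similarly, a hypothetical dense sketch of this second option (illustrative only; the broadcast across remaining output dimensions is again an assumption of the sketch):

```
import torch

def spdiags_per_offset_dims_dense(diagonals, offsets, dims, shape):
    # Hypothetical dense reference for the second proposed N-D signature:
    # a single 2-D `diagonals` tensor, with dims[i] giving the dimension pair
    # that offsets[i] refers to.
    out = torch.zeros(shape, dtype=diagonals.dtype)
    for row, off, (d0, d1) in zip(diagonals, offsets.tolist(), dims.tolist()):
        view = out.diagonal(off, dim0=d0, dim1=d1)
        src = row[off:] if off > 0 else row
        n = min(view.shape[-1], src.shape[0])
        view[..., :n] = src[:n]
    return out
```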

In both proposed signatures for the N-D case the specialization back to the 2-D signature is a bit of a stretch for typical default-argument logic; however, I think the first is the better choice as it offers more flexibility.

I think some discussion is required about:
- [x] Should the N-D output case be implemented from the outset
- [x] If not, should the future addition of the N-D output case be considered when designing the interface.
- [x] Other thoughts on the signature which includes the `dims` information for the N-D output case.

**Resolution**: Since no one has requested N-D output support, I think it is fine to restrict this to sparse matrix generation. Should a request for N-D support come later, an overload accepting the additional `dims` argument could be added.

Pull Request resolved: pytorch#78439
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/pearu
amjames authored and pytorchmergebot committed Jul 1, 2022
1 parent e5162dc commit 5a4c9e8
Showing 8 changed files with 378 additions and 1 deletion.
74 changes: 74 additions & 0 deletions aten/src/ATen/native/cpu/SparseFactories.cpp
@@ -0,0 +1,74 @@
#include <ATen/Dispatch.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/sparse/SparseFactories.h>
#include <c10/core/Scalar.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sparse_coo_tensor.h>
#endif

namespace at {
namespace native {
using namespace at::sparse;

namespace {
void _spdiags_kernel_cpu(
TensorIterator& iter,
const Tensor& diagonals,
Tensor& values,
Tensor& indices) {
auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
const int64_t diagonals_read_stride = diagonals.stride(1);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::BFloat16,
at::ScalarType::Half,
at::ScalarType::Bool,
at::ScalarType::ComplexHalf,
diagonals.scalar_type(),
"spdiags_cpu",
[&] {
auto* values_write_ptr = values.data_ptr<scalar_t>();
cpu_kernel(
iter,
[&](int64_t diag_index,
int64_t diag_offset,
int64_t out_offset,
int64_t n_out) -> int64_t {
if (n_out > 0) {
auto* rows_start = row_index_write_ptr + out_offset;
auto* cols_start = col_index_write_ptr + out_offset;
auto* vals_start = values_write_ptr + out_offset;
const int64_t first_col = std::max<int64_t>(diag_offset, 0);
const int64_t first_row = first_col - diag_offset;
auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
first_col * diagonals_read_stride;
for (int64_t i = 0; i < n_out; ++i) {
rows_start[i] = first_row + i;
cols_start[i] = first_col + i;
vals_start[i] = data_read[i * diagonals_read_stride];
}
}
// dummy return
return 0;
});
});
}

} // namespace

REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu)

} // namespace native
} // namespace at
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -5281,6 +5281,11 @@
SparseCPU: log_softmax_backward_sparse_cpu
SparseCUDA: log_softmax_backward_sparse_cuda

- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
python_module: sparse
dispatch:
CPU: spdiags

- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
Expand Down
95 changes: 95 additions & 0 deletions aten/src/ATen/native/sparse/SparseFactories.cpp
@@ -0,0 +1,95 @@
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/SparseFactories.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unique.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(spdiags_kernel_stub);

Tensor spdiags(
const Tensor& diagonals,
const Tensor& offsets,
IntArrayRef shape,
c10::optional<Layout> layout) {
auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
TORCH_CHECK(
diagonals_2d.size(0) == offsets_1d.size(0),
"Number of diagonals (",
diagonals_2d.size(0),
") does not match the number of offsets (",
offsets_1d.size(0),
")");
if (layout) {
TORCH_CHECK(
(*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
(*layout == Layout::SparseCsr),
"Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
*layout);
}
TORCH_CHECK(
offsets_1d.scalar_type() == at::kLong,
"Offset Tensor must have dtype Long but got ",
offsets_1d.scalar_type());

TORCH_CHECK(
offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
"Offset tensor contains duplicate values");

auto nnz_per_diag = at::where(
offsets_1d.le(0),
offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());

auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
const auto nnz = diagonals_2d.size(0) > 0
? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
: int64_t{0};
// Offsets into nnz for each diagonal
auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
// coo tensor guts
auto indices = at::empty({2, nnz}, offsets_1d.options());
auto values = at::empty({nnz}, diagonals_2d.options());
// We add this indexer to lookup the row of diagonals we are reading from at
// each iteration
const auto n_diag = offsets_1d.size(0);
Tensor diag_index = at::arange(n_diag, offsets_1d.options());
// cpu_kernel requires an output
auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.add_output(dummy)
.add_input(diag_index)
.add_input(offsets_1d)
.add_input(result_mem_offsets)
.add_input(nnz_per_diag)
.build();
spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
auto result_coo = at::sparse_coo_tensor(indices, values, shape);
if (layout) {
if (*layout == Layout::SparseCsr) {
return result_coo.to_sparse_csr();
}
if (*layout == Layout::SparseCsc) {
return result_coo.to_sparse_csc();
}
}
return result_coo;
}

} // namespace native
} // namespace at
15 changes: 15 additions & 0 deletions aten/src/ATen/native/sparse/SparseFactories.h
@@ -0,0 +1,15 @@
#pragma once
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

using spdiags_kernel_fn_t =
void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);

DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
} // namespace native
} // namespace at
2 changes: 2 additions & 0 deletions build_variables.bzl
@@ -1155,6 +1155,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
"aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
"aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
"aten/src/ATen/native/cpu/SparseFactories.cpp",
"aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]

@@ -1357,6 +1358,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/sparse/SparseTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseFactories.cpp",
"aten/src/ATen/native/transformers/attention.cpp",
"aten/src/ATen/native/transformers/transformer.cpp",
"aten/src/ATen/native/utils/Factory.cpp",
1 change: 1 addition & 0 deletions docs/source/sparse.rst
@@ -599,6 +599,7 @@ Torch functions specific to sparse Tensors
smm
sparse.softmax
sparse.log_softmax
sparse.spdiags

Other functions
+++++++++++++++
93 changes: 92 additions & 1 deletion test/test_sparse.py
@@ -8,7 +8,7 @@
import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
@@ -26,6 +26,9 @@
floating_and_complex_types_and, integral_types, floating_types_and,
)

if TEST_SCIPY:
import scipy.sparse

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@@ -3558,6 +3561,94 @@ def test(sparse_dims, nnz, with_size, new_size):
test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])

@unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
@onlyCPU
@dtypes(*all_types_and_complex_and(torch.bool))
def test_sparse_spdiags(self, device, dtype):

make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)

if TEST_SCIPY:
def reference(diags, offsets, shape):
return scipy.sparse.spdiags(diags, offsets, *shape).toarray()

else:
def reference(diags, offsets, shape):
result = torch.zeros(shape, dtype=dtype, device=device)
for i, off in enumerate(offsets):
res_view = result.diagonal(off)
data = diags[i]
if off > 0:
data = data[off:]

m = min(res_view.shape[0], data.shape[0])
res_view[:m] = data[:m]
return result

def check_valid(diags, offsets, shape, layout=None):
ref_out = reference(diags, offsets, shape)
out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
if layout is None:
ex_layout = torch.sparse_coo
else:
ex_layout = layout
out_dense = out.to_dense()
self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")

def check_invalid(args, error):
with self.assertRaisesRegex(RuntimeError, error):
torch.sparse.spdiags(*args)

def valid_cases():
# some normal cases
yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
# noncontiguous diags
yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
# noncontiguous offsets
yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
# noncontiguous diags + offsets
yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
# correct dimensionality, 2d, 2d, and shapes match, but the number of diagonals is zero
yield (make_diags((0, 3)), make_offsets([]), (3, 3))
# forward rotation of upper diagonals
yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
# rotation exhausts input space to read from
yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
# Simple cases repeated with special output format
yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
# vector diags
yield (make_diags((3, )), make_offsets([1]), (4, 4))
# Scalar offset
yield (make_diags((1, 3)), make_offsets(2), (4, 4))
# offsets out of range
yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))

for case in valid_cases():
check_valid(*case)

def invalid_cases():
yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"


for case, error_regex in invalid_cases():
check_invalid(case, error_regex)



class TestSparseOneOff(TestCase):
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')