Add spdiags sparse matrix initialization (pytorch#78439)
Similar to [scipy.sparse.spdiags](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.spdiags.html#scipy-sparse-spdiags). Part of pytorch#70926.

In other functions (e.g. [torch.diagonal](https://pytorch.org/docs/stable/generated/torch.diagonal.html#torch.diagonal)), diagonals of a tensor are referenced using the offset and the two dimensions that the diagonal is taken with respect to. Here the reference implementation from scipy only considers matrix output, so even if we only support 2-D output at first, it may be useful to consider how the dimensions corresponding to each diagonal would be specified for higher-dimensional output. The proposed torch signature implies that all offsets refer to diagonals with respect to the only two dimensions of the output:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, int[] shape, Layout? layout=None) -> SparseTensor
```

Above it is required that `diagonals.ndimension() == 2`, `offsets.ndimension() == 1`, `offsets.shape[0] == diagonals.shape[0]`, and `len(shape) == 2`. This would need to be altered for the case where `len(shape) > 2`. One option is:

```
torch.sparse.spdiags(Tensor[] diagonals, IntTensor[] offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```

Here `offsets` and `diagonals` become lists of tensors, and the `IntTensor dims` argument is introduced. This would require that `len(diagonals) == len(offsets) == dims.shape[0]`, `dims.ndimension() == 2`, and `dims.shape[1] == 2`; the same restrictions as in the 2-D case above apply to the elements of `diagonals` and `offsets` pairwise (that is, `diagonals[i].ndimension() == 2`, `offsets[i].ndimension() == 1`, and `offsets[i].shape[0] == diagonals[i].shape[0]` for all i). This form of the signature would construct the sparse result by placing the values from `diagonals[i][j]` into the diagonal with offset `offsets[i][j]` taken with respect to dimensions `dims[i]`. The specialization back to the original signature for the 2-D case could be seen as allowing the single row of `dims` to default to `[0, 1]` when only one `diagonals`/`offsets` pair is provided and `shape` is 2-D. This option allows the rows of an input element `diagonals[i]` to have different lengths, which may be appropriate as the max length of a diagonal along different dimension pairs will differ.

Another option is to specify the dimensions the diagonal is taken with respect to for each offset. This signature would look like:

```
torch.sparse.spdiags(Tensor diagonals, IntTensor offsets, IntTensor dims, int[] shape, Layout? layout=None) -> SparseTensor
```

Here `diagonals` is still 2-D with dimension 0 matching the length of the 1-D `offsets`, and the tensor input `dims` is also 2-D with dimension 0 matching the length of `offsets` and its second dimension fixed at 2. In this case the sparse result is constructed by placing the elements from `diagonals[i]` into the output diagonal `output.diagonal(offsets[i], dim0=dims[i][0], dim1=dims[i][1])` (with some additional consideration that makes it more complicated than simply assigning to that view). The specialization from this back to the 2-D form could be seen as assuming `dims = [[0, 1], [0, 1], ... len(offsets) times]` when `len(shape) == 2`.

In both proposed signatures for the N-D case, the specialization back to the 2-D signature is a bit of a stretch for typical default-argument logic; however, I think the first is the better choice as it offers more flexibility.
I think some discussion is required about:

- [x] Should the N-D output case be implemented from the outset?
- [x] If not, should the future addition of the N-D output case be considered when designing the interface?
- [x] Other thoughts on the signature which includes the `dims` information for the N-D output case.

**Resolution**: Since no one has offered a request for N-D output support, I think it is fine to restrict this to sparse matrix generation. Should a request for N-D support come later, an overload accepting the additional `dims` could be added.

Pull Request resolved: pytorch#78439
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/pearu
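For reference, a minimal usage sketch of the adopted 2-D signature; the expected output here is derived from the column-indexed placement implemented by the kernel below, matching the scipy convention:

```python
import torch

# Two diagonals placed into a 3x3 output: offset 0 (main diagonal) and
# offset -1 (first subdiagonal). Values are read from each row of
# `diagonals` starting at column max(offset, 0), as in scipy.sparse.spdiags.
diagonals = torch.tensor([[1, 2, 3], [4, 5, 6]])
offsets = torch.tensor([0, -1])
out = torch.sparse.spdiags(diagonals, offsets, (3, 3))
print(out.to_dense())
# tensor([[1, 0, 0],
#         [4, 2, 0],
#         [0, 5, 3]])
```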
1 parent e5162dc · commit 5a4c9e8
Showing 8 changed files with 378 additions and 1 deletion.
@@ -0,0 +1,74 @@
```cpp
#include <ATen/Dispatch.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/sparse/SparseFactories.h>
#include <c10/core/Scalar.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/sparse_coo_tensor.h>
#endif

namespace at {
namespace native {
using namespace at::sparse;

namespace {
void _spdiags_kernel_cpu(
    TensorIterator& iter,
    const Tensor& diagonals,
    Tensor& values,
    Tensor& indices) {
  auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
  auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
  const int64_t diagonals_read_stride = diagonals.stride(1);
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      at::ScalarType::Bool,
      at::ScalarType::ComplexHalf,
      diagonals.scalar_type(),
      "spdiags_cpu",
      [&] {
        auto* values_write_ptr = values.data_ptr<scalar_t>();
        cpu_kernel(
            iter,
            [&](int64_t diag_index,
                int64_t diag_offset,
                int64_t out_offset,
                int64_t n_out) -> int64_t {
              if (n_out > 0) {
                auto* rows_start = row_index_write_ptr + out_offset;
                auto* cols_start = col_index_write_ptr + out_offset;
                auto* vals_start = values_write_ptr + out_offset;
                const int64_t first_col = std::max<int64_t>(diag_offset, 0);
                const int64_t first_row = first_col - diag_offset;
                auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
                    first_col * diagonals_read_stride;
                for (int64_t i = 0; i < n_out; ++i) {
                  rows_start[i] = first_row + i;
                  cols_start[i] = first_col + i;
                  vals_start[i] = data_read[i * diagonals_read_stride];
                }
              }
              // dummy return
              return 0;
            });
      });
}

} // namespace

REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu);

} // namespace native
} // namespace at
```
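To make the kernel's indexing concrete, here is a hypothetical Python restatement of its per-diagonal loop; the helper name and structure are illustrative only, not part of the commit:

```python
import torch

# Illustrative restatement of _spdiags_kernel_cpu: for each diagonal,
# emit the COO (row, col, value) triples it contributes to the output.
def spdiags_coo_entries(diagonals_2d, offsets_1d, shape):
    rows, cols, vals = [], [], []
    max_diag_len = diagonals_2d.size(1)
    for diag, off in zip(diagonals_2d, offsets_1d.tolist()):
        # n_out mirrors nnz_per_diag in the dispatcher (next file): lower and
        # main diagonals are limited by the row count, upper diagonals by the
        # column count, both clamped to the stored row length.
        if off <= 0:
            n_out = min(off + shape[0], max_diag_len)
        else:
            n_out = min(shape[1], max_diag_len) - off
        first_col = max(off, 0)      # first_col in the kernel
        first_row = first_col - off  # first_row in the kernel
        for i in range(max(n_out, 0)):
            rows.append(first_row + i)
            cols.append(first_col + i)
            vals.append(diag[first_col + i].item())
    return rows, cols, vals
```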
@@ -0,0 +1,95 @@
```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/sparse/SparseFactories.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unique.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {

DEFINE_DISPATCH(spdiags_kernel_stub);

Tensor spdiags(
    const Tensor& diagonals,
    const Tensor& offsets,
    IntArrayRef shape,
    c10::optional<Layout> layout) {
  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
  TORCH_CHECK(
      diagonals_2d.size(0) == offsets_1d.size(0),
      "Number of diagonals (",
      diagonals_2d.size(0),
      ") does not match the number of offsets (",
      offsets_1d.size(0),
      ")");
  if (layout) {
    TORCH_CHECK(
        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
            (*layout == Layout::SparseCsr),
        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
        *layout);
  }
  TORCH_CHECK(
      offsets_1d.scalar_type() == at::kLong,
      "Offset Tensor must have dtype Long but got ",
      offsets_1d.scalar_type());

  TORCH_CHECK(
      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
      "Offset tensor contains duplicate values");

  auto nnz_per_diag = at::where(
      offsets_1d.le(0),
      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());

  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
  const auto nnz = diagonals_2d.size(0) > 0
      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
      : int64_t{0};
  // Offsets into nnz for each diagonal
  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
  // coo tensor guts
  auto indices = at::empty({2, nnz}, offsets_1d.options());
  auto values = at::empty({nnz}, diagonals_2d.options());
  // We add this indexer to lookup the row of diagonals we are reading from at
  // each iteration
  const auto n_diag = offsets_1d.size(0);
  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
  // cpu_kernel requires an output
  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
  auto iter = TensorIteratorConfig()
                  .set_check_mem_overlap(false)
                  .add_output(dummy)
                  .add_input(diag_index)
                  .add_input(offsets_1d)
                  .add_input(result_mem_offsets)
                  .add_input(nnz_per_diag)
                  .build();
  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
  if (layout) {
    if (*layout == Layout::SparseCsr) {
      return result_coo.to_sparse_csr();
    }
    if (*layout == Layout::SparseCsc) {
      return result_coo.to_sparse_csc();
    }
  }
  return result_coo;
}

} // namespace native
} // namespace at
```
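A small worked example of the nnz bookkeeping above, with values chosen for illustration:

```python
import torch

# A (4, 5) output with three stored diagonals of row length 3.
offsets = torch.tensor([-2, 0, 1])
m, n, diag_len = 4, 5, 3

nnz_per_diag = torch.where(
    offsets.le(0),
    offsets.add(m).clamp_max(diag_len),    # lower/main: limited by row count
    offsets.add(-min(n, diag_len)).neg())  # upper: limited by row length
print(nnz_per_diag)              # tensor([2, 3, 2]) -> total nnz = 7
cumsum = nnz_per_diag.cumsum(-1)             # tensor([2, 5, 7])
print(cumsum.sub(nnz_per_diag))  # write offsets per diagonal: tensor([0, 2, 5])
```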
@@ -0,0 +1,15 @@
```cpp
#pragma once
#include <ATen/TensorIterator.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {

using spdiags_kernel_fn_t =
    void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);

DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
} // namespace native
} // namespace at
```