Row and column select support for block compressed sparse tensors (pytorch#88733)

As in the title:

- Support `select` and `select_copy` on block sparse compressed tensors
- Fix incorrect results when selecting dense dimensions
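
For illustration, a minimal usage sketch of the new capability on a hypothetical BSR tensor (not taken from the PR's tests; assumes a PyTorch build that includes this PR):

```python
import torch

# Hypothetical 4x4 BSR tensor with 2x2 blocks
a = torch.tensor([[1., 2., 0., 0.],
                  [3., 4., 0., 0.],
                  [0., 0., 5., 6.],
                  [0., 0., 7., 8.]]).to_sparse_bsr((2, 2))

print(a.select(0, 1).to_dense())              # row 1:    tensor([3., 4., 0., 0.])
print(torch.select_copy(a, 1, 2).to_dense())  # column 2: tensor([0., 0., 5., 7.])
```

Selecting a row or column of a sparse compressed tensor returns a sparse COO vector. `select()` returns a view of `self.values()` whenever one exists and raises otherwise (suggesting `torch.select_copy`), while `select_copy()` always materializes an independent result.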

The PR also improves the performance of indexing sparse compressed tensors considerably:

<details>

Before:

```python
In [3]: a=torch.rand((1000, 1000)).to_sparse_csr()

In [4]: %timeit a.select(0, 0)
606 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [5]: %timeit a.select(1, 0)
527 µs ± 57.7 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [6]: %timeit a[0, 0]
617 µs ± 3.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [7]: a = a.cuda()

In [8]: %timeit a.select(0, 0); torch.cuda.synchronize();
1.19 ms ± 137 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [9]: %timeit a.select(1, 0); torch.cuda.synchronize();
1.2 ms ± 119 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [10]: %timeit a[0, 0]; torch.cuda.synchronize();
1.23 ms ± 482 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

After (this PR):

```python
In [3]: a=torch.rand((1000, 1000)).to_sparse_csr()

In [4]: %timeit a.select(0, 0)
4.75 µs ± 8.94 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [5]: %timeit a.select(1, 0)
565 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [6]: %timeit a[0, 0]
13.1 µs ± 435 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [7]: a = a.cuda()

In [8]: %timeit a.select(0, 0); torch.cuda.synchronize();
21.6 µs ± 23.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [9]: %timeit a.select(1, 0); torch.cuda.synchronize();
1.15 ms ± 3.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [10]: %timeit a[0, 0]; torch.cuda.synchronize();
63.7 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

</details>

Pull Request resolved: pytorch#88733
Approved by: https://github.com/nikitaved, https://github.com/amjames, https://github.com/cpuhrsch
pearu authored and pytorchmergebot committed Nov 30, 2022
1 parent 0cc0e5e commit 296e1ba
Showing 5 changed files with 351 additions and 66 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -12921,6 +12921,7 @@
variants: function
dispatch:
CompositeExplicitAutogradNonFunctional: select_copy_symint
SparseCsrCPU, SparseCsrCUDA: select_copy_sparse_csr
tags: view_copy

- func: detach_copy(Tensor self) -> Tensor
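
With the dispatch entry added above, `torch.select_copy` on sparse compressed tensors routes to the new `select_copy_sparse_csr` kernel. A quick sketch of the call path (assumes a build with this PR):

```python
import torch

a = torch.eye(3).to_sparse_csr()
r = torch.select_copy(a, 0, 1)  # dispatches to select_copy_sparse_csr
print(r.to_dense())             # tensor([0., 1., 0.])
```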
299 changes: 283 additions & 16 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -23,6 +23,7 @@
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
#include <ATen/ops/_validate_sparse_csc_tensor_args_native.h>
@@ -42,6 +43,8 @@
#include <ATen/ops/resize_native.h>
#include <ATen/ops/row_indices_native.h>
#include <ATen/ops/select_native.h>
#include <ATen/ops/select_copy.h>
#include <ATen/ops/select_copy_native.h>
#include <ATen/ops/sparse_compressed_tensor_native.h>
#include <ATen/ops/sparse_csr_tensor_native.h>
#include <ATen/ops/sparse_csc_tensor_native.h>
@@ -50,6 +53,7 @@
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/values_native.h>
#include <ATen/ops/_validate_compressed_sparse_indices.h>
#include <ATen/ops/where.h>
#endif

namespace at {
@@ -59,6 +63,50 @@ using namespace at::sparse_csr;

namespace {

bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) {
/*
This function solves the equation
input == arange(start, end, step)
for integers start, end, and step, if possible. If the solution
exists, returns true.
*/
int64_t n = input.numel();
if (n == 0) {
// a trivial solution
start = end = 0;
step = 1;
} else if (n == 1) {
// a simple solution
start = input[0].item<int64_t>();
end = start + 1;
step = 1;
} else {
Tensor first_last = input.slice(0, 0, n, n - 1).cpu();
int64_t start_candidate = first_last[0].item<int64_t>();
int64_t end_candidate = first_last[1].item<int64_t>() + 1;
if (end_candidate - start_candidate == n) {
// a special solution
start = start_candidate;
end = end_candidate;
step = 1;
} else {
// detect if general solution exists
Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1));
Tensor possible_step = possible_steps[0];
if ((possible_steps.eq(possible_step)).all().item<bool>()) {
start = start_candidate;
end = end_candidate;
step = possible_step.item<int64_t>();
} else {
// no solution
return false;
}
}
}
return true;
}

} // end anonymous namespace
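
In Python terms, `solve_arange` behaves roughly as follows (an illustrative sketch, not the committed code):

```python
import torch

def solve_arange(t: torch.Tensor):
    """Return (start, end, step) with t == torch.arange(start, end, step), or None."""
    n = t.numel()
    if n == 0:
        return 0, 0, 1                         # trivial solution
    if n == 1:
        s = int(t[0])
        return s, s + 1, 1                     # simple solution
    start, last = int(t[0]), int(t[-1])
    if last + 1 - start == n:
        return start, last + 1, 1              # contiguous run: step 1
    steps = t[1:] - t[:-1]
    if bool((steps == steps[0]).all()):
        return start, last + 1, int(steps[0])  # constant stride: general solution
    return None                                # no solution

assert solve_arange(torch.arange(3, 11, 2)) == (3, 10, 2)  # arange(3, 10, 2) reproduces the input
assert solve_arange(torch.tensor([0, 1, 3])) is None
```

When a solution exists, the positions picked out by a select can be expressed as a strided slice of the indices and values, which lets the sparse-dimension select below remain a view operation.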

@@ -744,17 +792,19 @@ Tensor empty_like_sparse_csr(
}
}

Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
template <bool require_view, bool require_copy>
Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
self.layout(), "select()", []() { return; });
self.layout(), "select", []() { return; });
TORCH_CHECK_INDEX(
self.dim() != 0, "select() cannot be applied to a 0-dim tensor.");
self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor.");
dim = maybe_wrap_dim(dim, self.dim());
auto size = self.size(dim);
if (index < -size || index >= size) {
TORCH_CHECK_INDEX(
false,
"select(): index ",
select_name, ": index ",
index,
" out of range for tensor of size ",
self.sizes(),
@@ -765,6 +815,14 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
index += size;
}

auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) {
if (require_copy) {
return at::select_copy(self, dim, index);
} else {
return self.select(dim, index);
}
};

TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());

auto new_sizes = DimVector(self.sizes());
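
The `select_strided` helper above merely threads the view-versus-copy choice through to the strided `values` tensor; on strided tensors the two variants differ as usual (a minimal illustration):

```python
import torch

x = torch.arange(6.).reshape(2, 3)
v = x.select(0, 1)               # a view of row 1
c = torch.select_copy(x, 0, 1)   # an independent copy of row 1
v[0] = 100.
assert x[1, 0].item() == 100.    # the view writes through to x
assert c[0].item() == 3.         # the copy is unaffected
```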
@@ -790,36 +848,245 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices.select(dim, index),
plain_indices.select(dim, index),
self.values().select(dim, index),
select_strided(self.values(), dim, index),
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt());
} else if (dim < n_batch + 2) {
// Selecting sparse dimension
TORCH_CHECK(
self.layout() == kSparseCsr || self.layout() == kSparseCsc,
"select(): selecting non-batch dimensions is currently only supported for non-blocked sparse compressed layouts tensors.");
TORCH_CHECK(
n_batch == 0,
"select(): selecting rows or columns is not implemented for batched sparse compressed tensors.")
// Converting to COO and calling select is slightly slower than operating
// on the CSR indices directly for constructing a COO vector, however
// current version is more readable and easier to understand.
return self.to_sparse().select(dim, index);
select_name, ": selecting sparse dimensions is not implemented for batched sparse compressed tensors.")
TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1);

DimVector blocksize{1, 1};
AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] {
blocksize[0] = std::max<int64_t>(1, self.values().size(n_batch + 1));
blocksize[1] = std::max<int64_t>(1, self.values().size(n_batch + 2));
});

auto indices_options = compressed_indices.options();
int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; });
int64_t other_dim = (dim == 0 ? 1 : 0);
Tensor indices;
Tensor values;
bool is_view = dim == fast_dim;
if (is_view) {
// select is always a view operation
Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu();
int64_t start = start_end[0].item<int64_t>();
int64_t end = start_end[1].item<int64_t>();
indices = plain_indices.slice(0, start, end);
values = self.values().slice(0, start, end);
} else {
Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
.select(0, 0);

Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0];
// Notice that dim_indices is a sorted sequence of non-negative
// distinct integers. Below we'll try to solve `dim_indices ==
// arange(start, stop, step)`. If the solution exists then the
// select will be a view operation also for the `dim !=
// fast_dim` case.
int64_t start{}, end{}, step{};
if (solve_arange(dim_indices, start, end, step)) {
indices = decompressed_indices.slice(0, start, end, step);
values = self.values().slice(0, start, end, step);
is_view = true;
} else {
// select will be a copy operation due to index_select!
indices = decompressed_indices.index_select(0, dim_indices);
values = self.values().index_select(0, dim_indices);
}
}

AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {},
[&]() {
/*
The formula for select indices and values below are best
explained by an example. Consider a BSR tensor with a
block size (2, 3) having four blocks (the other two blocks
contain all zeros and hence will not be specified):
[ 1 2 3] | [ 7 8 9]
[ 4 5 6] | [10 11 12]
---------------------
[13 14 15] | [ 0 0 0]
[16 17 18] | [ 0 0 0]
-----------------------
[ 0 0 0] | [19 20 21]
[ 0 0 0] | [22 23 24]
that represents a 6 x 6 tensor:
[ 1 2 3 7 8 9 ]
[ 4 5 6 10 11 12 ]
[ 13 14 15 0 0 0 ]
[ 16 17 18 0 0 0 ]
[ 0 0 0 19 20 21 ]
[ 0 0 0 22 23 24 ]
The corresponding data for the BSR representation is:
crow_indices = [0 2 3 4]
col_indices = [0 1 0 1]
values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ]
shape = (6, 6)
From crow_indices, we can find that
row_indices = [0 0 1 2]
In the following, we'll illustrate the details of
computing the result of torch.select_copy(input, dim,
index) where dim is 0 or 1, and index is in
range(shape[dim]).
Select a row of a BSR tensor
----------------------------
We will consider first the dim=0 case that corresponds to
selecting a index-th row of the tensor. For instance, for
dim=0 and index=1, the expected result would represent a
1D tensor:
[ 4 5 6 10 11 12 ]
that is a concatenated tensor of certain slices from the
first and the second block that is computed as follows:
values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
-> values[[0, 1]][:, 1 % 2].flatten(0, 1)
-> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1)
-> [ [4 5 6], [10 11 12]].flatten(0, 1)
-> [ 4 5 6 10 11 12]
where dim_indices is found as
where(row_indices == index//blocksize[dim])
-> where([0 0 1 2] == 1//2)
-> [0 1]
The corresponding column indices are computed as
(col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's
expand the above expression with the data in the example:
-> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1)
-> ([[0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1)
-> ([[[0], [3]] + [[0 1 2]]).flatten(0, 1) <- here addition will use broadcasting rules!
-> ([[[0 1 2], [3 4 5]]).flatten(0, 1)
-> [0 1 2 3 4 5]
Finally, the select(dim=0, index=1) op on the given sparse
compressed tensors will return a COO tensor:
sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,))
that represents the expected result: [ 4 5 6 10 11 12 ]
Select a column of a BSR tensor
-------------------------------
Next, we'll consider the dim=1 case that corresponds to
selecting the index-th column of the tensor. For instance,
for dim=1 and index=4, the expected result would represent
a 1D tensor:
[ 8 11 0 0 20 23]
that is a concatenated tensor of certain slices from the
second and the last block:
values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1)
-> values[[1, 3]][:, :, 4 % 3 ].flatten(0, 1)
-> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, 1, 1].flatten(0, 1)
-> [ [8 11], [20 23]].flatten(0, 1)
-> [ 8 11 20 23 ]
The corresponding row indices are computed as
(row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
where dim_indices is
where(col_indices == index//blocksize[dim])
-> where([0 1 0 1] == 4//3)
-> [1 3]
and we have
(row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1)
-> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1)
-> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1)
-> ([[0], [4]] + [[0 1]]).flatten(0, 1) <- here addition will use broadcasting rules!
-> ([[0 1], [4 5]]).flatten(0, 1)
-> [ 0 1 4 5 ]
Finally, the select(dim=1, index=4) op on the given sparse
compressed tensors will return a COO tensor:
sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,))
that represents the expected result: [ 8 11 0 0 20 23 ]
*/
Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options);
indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1);
values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1);
// flatten(0, 1) can be a view or a copy operation. If view
// is required, it will be checked below via is_alias_of,
// otherwise, we'll check if copy is made here to avoid
// unnecessary clone below:
if (require_copy) {
is_view = values.is_alias_of(self.values());
}
});

if (require_view) {
TORCH_CHECK(values.is_alias_of(self.values()), select_name,
": no view exists for the given input, consider using torch.select_copy.");
}

indices = indices.unsqueeze(0).to(kLong);
if (require_copy && is_view) {
values = values.clone();
}
return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true);
} else {
// Selecting dense dimension
return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
self.layout(),
"select",
// Non blocked layout (2 sparse dims become 1 nnz dim in values, so dim
// is found one position to the left)
[&]() { return self.values().select(dim - 1, index); },
[&]() { return select_strided(self.values(), dim - 1, index); },
// Block layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in
// values, so dim is found 1 position to the right)
[&]() { return self.values().select(dim + 1, index); });
[&]() { return select_strided(self.values(), dim + 1, index); });
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
new_values,
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt());
}
}

Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
return select_sparse_csr_worker<true, false>(self, dim, index);
}

Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
return select_sparse_csr_worker<false, true>(self, dim, index);
}

} // namespace native
} // namespace at
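
As a cross-check of the worked example in the long comment above (a sketch, under the assumption that dense-to-BSR conversion prunes all-zero blocks, leaving exactly the four stored blocks from the comment):

```python
import torch

# The 6x6 tensor from the in-code comment, as BSR with blocksize (2, 3)
dense = torch.tensor([[ 1.,  2.,  3.,  7.,  8.,  9.],
                      [ 4.,  5.,  6., 10., 11., 12.],
                      [13., 14., 15.,  0.,  0.,  0.],
                      [16., 17., 18.,  0.,  0.,  0.],
                      [ 0.,  0.,  0., 19., 20., 21.],
                      [ 0.,  0.,  0., 22., 23., 24.]])
a = dense.to_sparse_bsr((2, 3))

# dim=0, index=1: the comment derives [4 5 6 10 11 12]
assert torch.equal(torch.select_copy(a, 0, 1).to_dense(), dense[1])
# dim=1, index=4: the comment derives [8 11 0 0 20 23]
assert torch.equal(torch.select_copy(a, 1, 4).to_dense(), dense[:, 4])
```

The comment walks through `torch.select_copy` rather than `select` because in this example the flattened block slices cannot alias `self.values()`, so the view-requiring `select()` would hit the "no view exists" check above.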
3 changes: 1 addition & 2 deletions test/test_sparse.py
@@ -4052,8 +4052,7 @@ def test_basic(self):
class TestSparseAny(TestCase):

def test_generate_simple_inputs(self):
# Temporarily disable BSC and BSC layouts as these don't support select yet, see the next PR in the stack.
layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc][:-2]
layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc]

tested_combinations = set()
for tensors in zip(*map(self.generate_simple_inputs, layouts)):