forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Row and column select support for block compressed sparse tensors (py…
…torch#88733) As in the title: - Support `select` and `select_copy` on block sparse compressed tensors - Fixes incorrect results when selecting dense dimensions The PR also improves the performance of indexing sparse compressed tensors considerably: <details> Before: ```python In [3]: a=torch.rand((1000, 1000)).to_sparse_csr() In [4]: %timeit a.select(0, 0) 606 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [5]: %timeit a.select(1, 0) 527 µs ± 57.7 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [6]: %timeit a[0, 0] 617 µs ± 3.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [7]: a = a.cuda() In [8]: %timeit a.select(0, 0); torch.cuda.synchronize(); 1.19 ms ± 137 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [9]: %timeit a.select(1, 0); torch.cuda.synchronize(); 1.2 ms ± 119 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [10]: %timeit a[0, 0]; torch.cuda.synchronize(); 1.23 ms ± 482 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` This PR: ```python In [3]: a=torch.rand((1000, 1000)).to_sparse_csr() In [4]: %timeit a.select(0, 0) 4.75 µs ± 8.94 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) In [5]: %timeit a.select(1, 0) 565 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [6]: %timeit a[0, 0] 13.1 µs ± 435 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) In [7]: a = a.cuda() In [8]: %timeit a.select(0, 0); torch.cuda.synchronize(); 21.6 µs ± 23.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) In [9]: %timeit a.select(1, 0); torch.cuda.synchronize(); 1.15 ms ± 3.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [10]: %timeit a[0, 0]; torch.cuda.synchronize(); 63.7 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` </details> Pull Request resolved: pytorch#88733 Approved by: https://github.com/nikitaved, https://github.com/amjames, https://github.com/cpuhrsch
- Loading branch information
1 parent
0cc0e5e
commit 296e1ba
Showing
5 changed files
with
351 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.