forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sparse tensor printing; add NotImplemented autograd fn (pytorch#10181)
Summary: Commits: 1. Add autograd function `NotImplemented` (subclass of `Error`) so python `grad_fn` prints nicer. Since `Error` is used in `DelayedError` to implement `once_differentiable`, I can't just change its name. cc colesbury 2. Add printing for sparse tensors. Fixes pytorch#9412 . cc weiyangfb. 3. Add tests for sparse printing Examples: ```diff In [2]: x = torch.sparse.FloatTensor(torch.arange(4).view(2,2), torch.randn(2, 2), [10, 10, 2]) In [3]: x Out[3]: - torch.sparse.FloatTensor of size (10,10,2) with indices: - tensor([[0, 1], - [2, 3]]) - and values: - tensor([[-1.1832, -0.5927], - [ 0.0831, 0.2511]]) + tensor(indices=tensor([[0, 1], + [2, 3]]), + values=tensor([[ 1.5081, 0.3451], + [-0.0392, 0.4776]]), + size=(10, 10, 2), nnz=2, layout=torch.sparse_coo) In [4]: x.requires_grad_() Out[4]: - torch.sparse.FloatTensor of size (10,10,2) with indices: - tensor([[0, 1], - [2, 3]], grad_fn=<Error>) - and values: - tensor([[-1.1832, -0.5927], - [ 0.0831, 0.2511]], grad_fn=<Error>) + tensor(indices=tensor([[0, 1], + [2, 3]]), + values=tensor([[ 1.5081, 0.3451], + [-0.0392, 0.4776]]), + size=(10, 10, 2), nnz=2, layout=torch.sparse_coo, requires_grad=True) In [5]: x + x Out[5]: - torch.sparse.FloatTensor of size (10,10,2) with indices: - tensor([[0, 1], - [2, 3]], grad_fn=<Error>) - and values: - tensor([[-2.3664, -1.1855], - [ 0.1662, 0.5021]], grad_fn=<Error>) + tensor(indices=tensor([[0, 1], + [2, 3]]), + values=tensor([[ 3.0162, 0.6902], + [-0.0785, 0.9553]]), + size=(10, 10, 2), nnz=2, layout=torch.sparse_coo, grad_fn=<AddBackward0>) In [6]: x.double() Out[6]: - torch.sparse.DoubleTensor of size (10,10,2) with indices: - tensor([[0, 1], - [2, 3]], grad_fn=<Error>) - and values: - tensor([[-1.1832, -0.5927], - [ 0.0831, 0.2511]], dtype=torch.float64, grad_fn=<Error>) + tensor(indices=tensor([[0, 1], + [2, 3]]), + values=tensor([[ 1.5081, 0.3451], + [-0.0392, 0.4776]]), + size=(10, 10, 2), nnz=2, 
dtype=torch.float64, layout=torch.sparse_coo, + grad_fn=<NotImplemented>) In [7]: x = torch.sparse.FloatTensor(torch.ones(0, 2, dtype=torch.long), torch.randn(2, 0), [0]) In [8]: x Out[8]: - torch.sparse.FloatTensor of size (0,) with indices: - tensor([], size=(0, 2), dtype=torch.int64) - and values: - tensor([], size=(2, 0)) + tensor(indices=tensor([], size=(0, 2)), + values=tensor([], size=(2, 0)), + size=(0,), nnz=2, layout=torch.sparse_coo) In [9]: x = torch.sparse.FloatTensor(torch.ones(0, 2, dtype=torch.long), torch.randn(2), []) In [10]: x Out[10]: - torch.sparse.FloatTensor of size () with indices: - tensor([], size=(0, 2), dtype=torch.int64) - and values: - tensor([-0.0064, 0.8518]) + tensor(indices=tensor([], size=(0, 2)), + values=tensor([ 0.9800, -0.5978]), + size=(), nnz=2, layout=torch.sparse_coo) ``` Pull Request resolved: pytorch#10181 Differential Revision: D9139845 Pulled By: SsnL fbshipit-source-id: 353eebd55fac4049ed9bf85f8b0ee2c1418a744e
- Loading branch information
1 parent
fa147ab
commit 83a1ab2
Showing
12 changed files
with
1,286 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
# shape: torch.Size([]) | ||
# nnz: 2 | ||
# sparseDim: 0 | ||
# indices shape: torch.Size([0, 2]) | ||
# values shape: torch.Size([2]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 2)), | ||
values=tensor([0, 1]), | ||
device='cuda:0', size=(), nnz=2, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64) | ||
# _values | ||
tensor([0, 1], device='cuda:0', dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 2)), | ||
values=tensor([0., 1.]), | ||
device='cuda:0', size=(), nnz=2, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(0, 2)), | ||
values=tensor([0., 1.]), | ||
device='cuda:0', size=(), nnz=2, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(0, 4)), | ||
values=tensor([0., 1., 0., 1.]), | ||
device='cuda:0', size=(), nnz=4, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([0., 1.], device='cuda:0', dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([0]) | ||
# nnz: 10 | ||
# sparseDim: 0 | ||
# indices shape: torch.Size([0, 10]) | ||
# values shape: torch.Size([10, 0]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 10)), | ||
values=tensor([], size=(10, 0)), | ||
device='cuda:0', size=(0,), nnz=10, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 10), dtype=torch.int64) | ||
# _values | ||
tensor([], device='cuda:0', size=(10, 0), dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 10)), | ||
values=tensor([], size=(10, 0)), | ||
device='cuda:0', size=(0,), nnz=10, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(0, 10)), | ||
values=tensor([], size=(10, 0)), | ||
device='cuda:0', size=(0,), nnz=10, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(0, 20)), | ||
values=tensor([], size=(20, 0)), | ||
device='cuda:0', size=(0,), nnz=20, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 10), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([], device='cuda:0', size=(10, 0), dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([2]) | ||
# nnz: 3 | ||
# sparseDim: 0 | ||
# indices shape: torch.Size([0, 3]) | ||
# values shape: torch.Size([3, 2]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([[0, 0], | ||
[0, 1], | ||
[1, 1]]), | ||
device='cuda:0', size=(2,), nnz=3, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64) | ||
# _values | ||
tensor([[0, 0], | ||
[0, 1], | ||
[1, 1]], device='cuda:0', dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([[0.0000, 0.3333], | ||
[0.6667, 1.0000], | ||
[1.3333, 1.6667]]), | ||
device='cuda:0', size=(2,), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([[0.0000, 0.3333], | ||
[0.6667, 1.0000], | ||
[1.3333, 1.6667]]), | ||
device='cuda:0', size=(2,), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(0, 6)), | ||
values=tensor([[0.0000, 0.3333], | ||
[0.6667, 1.0000], | ||
[1.3333, 1.6667], | ||
[0.0000, 0.3333], | ||
[0.6667, 1.0000], | ||
[1.3333, 1.6667]]), | ||
device='cuda:0', size=(2,), nnz=6, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([[0.0000, 0.3333], | ||
[0.6667, 1.0000], | ||
[1.3333, 1.6667]], device='cuda:0', dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([100, 3]) | ||
# nnz: 3 | ||
# sparseDim: 1 | ||
# indices shape: torch.Size([1, 3]) | ||
# values shape: torch.Size([3, 3]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([[0, 1, 2]]), | ||
values=tensor([[0, 0, 0], | ||
[0, 0, 1], | ||
[1, 1, 1]]), | ||
device='cuda:0', size=(100, 3), nnz=3, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([[0, 1, 2]], device='cuda:0') | ||
# _values | ||
tensor([[0, 0, 0], | ||
[0, 0, 1], | ||
[1, 1, 1]], device='cuda:0', dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([[0, 1, 2]]), | ||
values=tensor([[0.0000, 0.2222, 0.4444], | ||
[0.6667, 0.8889, 1.1111], | ||
[1.3333, 1.5556, 1.7778]]), | ||
device='cuda:0', size=(100, 3), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([[0, 1, 2]]), | ||
values=tensor([[0.0000, 0.2222, 0.4444], | ||
[0.6667, 0.8889, 1.1111], | ||
[1.3333, 1.5556, 1.7778]]), | ||
device='cuda:0', size=(100, 3), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([[0, 1, 2, 0, 1, 2]]), | ||
values=tensor([[0.0000, 0.2222, 0.4444], | ||
[0.6667, 0.8889, 1.1111], | ||
[1.3333, 1.5556, 1.7778], | ||
[0.0000, 0.2222, 0.4444], | ||
[0.6667, 0.8889, 1.1111], | ||
[1.3333, 1.5556, 1.7778]]), | ||
device='cuda:0', size=(100, 3), nnz=6, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([[0, 1, 2]], device='cuda:0', grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([[0.0000, 0.2222, 0.4444], | ||
[0.6667, 0.8889, 1.1111], | ||
[1.3333, 1.5556, 1.7778]], device='cuda:0', dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([100, 20, 3]) | ||
# nnz: 0 | ||
# sparseDim: 2 | ||
# indices shape: torch.Size([2, 0]) | ||
# values shape: torch.Size([0, 3]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(2, 0)), | ||
values=tensor([], size=(0, 3)), | ||
device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(2, 0), dtype=torch.int64) | ||
# _values | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(2, 0)), | ||
values=tensor([], size=(0, 3)), | ||
device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(2, 0)), | ||
values=tensor([], size=(0, 3)), | ||
device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(2, 0)), | ||
values=tensor([], size=(0, 3)), | ||
device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(2, 0), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([10, 0, 3]) | ||
# nnz: 3 | ||
# sparseDim: 0 | ||
# indices shape: torch.Size([0, 3]) | ||
# values shape: torch.Size([3, 10, 0, 3]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([], size=(3, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64) | ||
# _values | ||
tensor([], device='cuda:0', size=(3, 10, 0, 3), dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([], size=(3, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(0, 3)), | ||
values=tensor([], size=(3, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(0, 6)), | ||
values=tensor([], size=(6, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=6, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([], device='cuda:0', size=(3, 10, 0, 3), dtype=torch.float32, | ||
grad_fn=<NotImplemented>) | ||
|
||
# shape: torch.Size([10, 0, 3]) | ||
# nnz: 0 | ||
# sparseDim: 0 | ||
# indices shape: torch.Size([0, 0]) | ||
# values shape: torch.Size([0, 10, 0, 3]) | ||
########## torch.int32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 0)), | ||
values=tensor([], size=(0, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.int32, | ||
layout=torch.sparse_coo) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 0), dtype=torch.int64) | ||
# _values | ||
tensor([], device='cuda:0', size=(0, 10, 0, 3), dtype=torch.int32) | ||
########## torch.float32 ########## | ||
# sparse tensor | ||
tensor(indices=tensor([], size=(0, 0)), | ||
values=tensor([], size=(0, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo) | ||
# after requires_grad_ | ||
tensor(indices=tensor([], size=(0, 0)), | ||
values=tensor([], size=(0, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo, requires_grad=True) | ||
# after addition | ||
tensor(indices=tensor([], size=(0, 0)), | ||
values=tensor([], size=(0, 10, 0, 3)), | ||
device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32, | ||
layout=torch.sparse_coo, grad_fn=<AddBackward0>) | ||
# _indices | ||
tensor([], device='cuda:0', size=(0, 0), dtype=torch.int64, | ||
grad_fn=<NotImplemented>) | ||
# _values | ||
tensor([], device='cuda:0', size=(0, 10, 0, 3), dtype=torch.float32, | ||
grad_fn=<NotImplemented>) |
Oops, something went wrong.