Sparse tensor printing; add NotImplemented autograd fn (pytorch#10181)
Summary:
Commits:

1. Add autograd function `NotImplemented` (a subclass of `Error`) so the Python `grad_fn` prints more nicely; a short sketch follows the examples below. Since `Error` is used in `DelayedError` to implement `once_differentiable`, I can't just rename it. cc colesbury

2. Add printing for sparse tensors. Fixes pytorch#9412. cc weiyangfb

3. Add tests for sparse tensor printing.

Examples:
```diff
  In [2]: x = torch.sparse.FloatTensor(torch.arange(4).view(2,2), torch.randn(2, 2), [10, 10, 2])

  In [3]: x
  Out[3]:
- torch.sparse.FloatTensor of size (10,10,2) with indices:
- tensor([[0, 1],
-         [2, 3]])
- and values:
- tensor([[-1.1832, -0.5927],
-         [ 0.0831,  0.2511]])
+ tensor(indices=tensor([[0, 1],
+                        [2, 3]]),
+        values=tensor([[ 1.5081,  0.3451],
+                       [-0.0392,  0.4776]]),
+        size=(10, 10, 2), nnz=2, layout=torch.sparse_coo)

  In [4]: x.requires_grad_()
  Out[4]:
- torch.sparse.FloatTensor of size (10,10,2) with indices:
- tensor([[0, 1],
-         [2, 3]], grad_fn=<Error>)
- and values:
- tensor([[-1.1832, -0.5927],
-         [ 0.0831,  0.2511]], grad_fn=<Error>)
+ tensor(indices=tensor([[0, 1],
+                        [2, 3]]),
+        values=tensor([[ 1.5081,  0.3451],
+                       [-0.0392,  0.4776]]),
+        size=(10, 10, 2), nnz=2, layout=torch.sparse_coo, requires_grad=True)

  In [5]: x + x
  Out[5]:
- torch.sparse.FloatTensor of size (10,10,2) with indices:
- tensor([[0, 1],
-         [2, 3]], grad_fn=<Error>)
- and values:
- tensor([[-2.3664, -1.1855],
-         [ 0.1662,  0.5021]], grad_fn=<Error>)
+ tensor(indices=tensor([[0, 1],
+                        [2, 3]]),
+        values=tensor([[ 3.0162,  0.6902],
+                       [-0.0785,  0.9553]]),
+        size=(10, 10, 2), nnz=2, layout=torch.sparse_coo, grad_fn=<AddBackward0>)

  In [6]: x.double()
  Out[6]:
- torch.sparse.DoubleTensor of size (10,10,2) with indices:
- tensor([[0, 1],
-         [2, 3]], grad_fn=<Error>)
- and values:
- tensor([[-1.1832, -0.5927],
-         [ 0.0831,  0.2511]], dtype=torch.float64, grad_fn=<Error>)
+ tensor(indices=tensor([[0, 1],
+                        [2, 3]]),
+        values=tensor([[ 1.5081,  0.3451],
+                       [-0.0392,  0.4776]]),
+        size=(10, 10, 2), nnz=2, dtype=torch.float64, layout=torch.sparse_coo,
+        grad_fn=<NotImplemented>)

  In [7]: x = torch.sparse.FloatTensor(torch.ones(0, 2, dtype=torch.long), torch.randn(2, 0), [0])

  In [8]: x
  Out[8]:
- torch.sparse.FloatTensor of size (0,) with indices:
- tensor([], size=(0, 2), dtype=torch.int64)
- and values:
- tensor([], size=(2, 0))
+ tensor(indices=tensor([], size=(0, 2)),
+        values=tensor([], size=(2, 0)),
+        size=(0,), nnz=2, layout=torch.sparse_coo)

  In [9]: x = torch.sparse.FloatTensor(torch.ones(0, 2, dtype=torch.long), torch.randn(2), [])

  In [10]: x
  Out[10]:
- torch.sparse.FloatTensor of size () with indices:
- tensor([], size=(0, 2), dtype=torch.int64)
- and values:
- tensor([-0.0064,  0.8518])
+ tensor(indices=tensor([], size=(0, 2)),
+        values=tensor([ 0.9800, -0.5978]),
+        size=(), nnz=2, layout=torch.sparse_coo)
```
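As a quick illustration of item 1 (not part of the PR itself): `_indices()` and `_values()` of a sparse tensor that requires grad are produced by non-differentiable functions, so their `grad_fn` should now print as `<NotImplemented>` rather than `<Error>`. The constructor call and the expected output in the comments below are assumptions based on the expect files in this commit, not copied output.

```python
import torch

# Minimal sketch; assumes torch.sparse_coo_tensor (as referenced in the
# error message touched by this PR) and a build that includes this change.
i = torch.tensor([[0, 1]])    # 1 sparse dim, nnz = 2
v = torch.tensor([1.0, 2.0])
x = torch.sparse_coo_tensor(i, v, (3,)).requires_grad_()

print(x)             # expected to end with ... requires_grad=True
print(x._values())   # expected to end with ... grad_fn=<NotImplemented>
```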
Pull Request resolved: pytorch#10181

Differential Revision: D9139845

Pulled By: SsnL

fbshipit-source-id: 353eebd55fac4049ed9bf85f8b0ee2c1418a744e
ssnl authored and facebook-github-bot committed Sep 6, 2018
1 parent fa147ab commit 83a1ab2
Showing 12 changed files with 1,286 additions and 60 deletions.
13 changes: 10 additions & 3 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
```diff
@@ -125,7 +125,7 @@ SparseTensor new_with_dims_and_size_sparse(const SparseType& dtype, int64_t spar
   SparseTensor self = new_sparse(dtype);
   AT_CHECK(size.size() != 0,
     "cannot construct sparse tensor with 0 dimensions and no values; you must specify at least 1 dimension if you want to create a sparse tensor with no elements, \
-or you must provide a single-element `values` tensor (e.g. x=torch.sparse_coo_tensor(torch.zeros(0,1), 12.3, [])) if you want to create a scalar sparse tensor");
+or you must provide a single-element `values` tensor (e.g. x = torch.sparse_coo_tensor(torch.zeros(0, 1), 12.3, [])) if you want to create a scalar sparse tensor");
   _get_sparse_impl(self)->resize_and_clear_(sparseDims, denseDims, size);
   return self;
 }
@@ -173,17 +173,24 @@ SparseTensor new_with_tensor_and_size_sparse(const LongTensor& indices, const Te
 
   // Check to make sure all indices are within the boundaries of `sizes`
   if (indices.numel() > 0) {
+    LongTensor min_indices = std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
     LongTensor max_indices = std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
-    LongTensor cpu_max_indices;
-    if (max_indices.is_cuda()) {
+    LongTensor cpu_min_indices, cpu_max_indices;
+    if (indices.is_cuda()) {
+      cpu_min_indices = at::CPU(kLong).copy(min_indices);
       cpu_max_indices = at::CPU(kLong).copy(max_indices);
     } else {
+      cpu_min_indices = min_indices;
       cpu_max_indices = max_indices;
     }
+    auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
     auto cpu_max_indices_accessor = cpu_max_indices.accessor<int64_t, 1>();
     for (int64_t d = 0; d < sparseDims; d++) {
       // NB: This used to sync ndim times to access each entry; now we copy
       // everything to CPU first and then access it.
+      int64_t min_index_in_dim = cpu_min_indices_accessor[d];
+      AT_CHECK(min_index_in_dim >= 0,
+               "found negative index ", min_index_in_dim, " for dim ", d);
       int64_t max_index_in_dim = cpu_max_indices_accessor[d];
       int64_t dim_size = sizes[static_cast<size_t>(d)];
       AT_CHECK(max_index_in_dim < dim_size,
```
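The second hunk above also adds a lower-bound check alongside the existing upper-bound check, so negative indices are now rejected when a sparse tensor is constructed with an explicit size. A hedged Python sketch of the user-visible effect (the constructor call and the exact error text are assumptions, not taken from this PR's tests):

```python
import torch

# Assumed behaviour after this change: a negative sparse index should be
# rejected at construction time with a message along the lines of
# "found negative index -1 for dim 0"; previously only indices >= the
# corresponding dimension size were caught.
i = torch.tensor([[-1, 0]])   # one negative index in dim 0
v = torch.tensor([1.0, 2.0])
try:
    torch.sparse_coo_tensor(i, v, (3,))
except RuntimeError as err:
    print(err)
```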
13 changes: 3 additions & 10 deletions aten/src/ATen/native/sparse/SparseUtils.h
```diff
@@ -108,16 +108,9 @@ inline LongTensor _newFlattenedIndices(const SparseTensor& self, bool forceClone
 // TODO: Expose this for real in ATen, some day?
 // NB: Doesn't preserve data.
 inline Tensor _new_values_with_size_of(const Tensor& values, int64_t nnz) {
-  if (values.numel() == 0) { // values tensor uninitialized
-    // TODO: This logic looks bogus; if we have an uninitialized
-    // values tensor, why should we believe that denseDims == 0?
-    // That's the assumption this code makes.
-    return values.type().tensor({nnz});
-  } else {
-    std::vector<int64_t> size = values.sizes().vec();
-    size[0] = nnz;
-    return values.type().tensor(size);
-  }
+  std::vector<int64_t> size = values.sizes().vec();
+  size[0] = nnz;
+  return values.type().tensor(size);
 }
```
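For context, a short Python sketch of the shape rule the simplified helper now applies unconditionally (illustrative only; `_new_values_with_size_of` is an internal C++ helper, and `new_empty` below merely stands in for `values.type().tensor(size)`):

```python
import torch

# The helper allocates a new, uninitialized values tensor whose first
# dimension is nnz and whose remaining (dense) dimensions match `values`.
values = torch.randn(5, 2, 3)              # nnz = 5, dense dims (2, 3)
nnz = 8
size = [nnz] + list(values.shape[1:])      # size[0] = nnz  ->  [8, 2, 3]
new_values = values.new_empty(size)        # data is not preserved
print(new_values.shape)                    # torch.Size([8, 2, 3])
```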
294 changes: 294 additions & 0 deletions test/expect/TestCudaSparse.test_print.expect
@@ -0,0 +1,294 @@
# shape: torch.Size([])
# nnz: 2
# sparseDim: 0
# indices shape: torch.Size([0, 2])
# values shape: torch.Size([2])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 2)),
       values=tensor([0, 1]),
       device='cuda:0', size=(), nnz=2, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64)
# _values
tensor([0, 1], device='cuda:0', dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 2)),
       values=tensor([0., 1.]),
       device='cuda:0', size=(), nnz=2, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(0, 2)),
       values=tensor([0., 1.]),
       device='cuda:0', size=(), nnz=2, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(0, 4)),
       values=tensor([0., 1., 0., 1.]),
       device='cuda:0', size=(), nnz=4, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(0, 2), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([0., 1.], device='cuda:0', dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([0])
# nnz: 10
# sparseDim: 0
# indices shape: torch.Size([0, 10])
# values shape: torch.Size([10, 0])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 10)),
       values=tensor([], size=(10, 0)),
       device='cuda:0', size=(0,), nnz=10, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(0, 10), dtype=torch.int64)
# _values
tensor([], device='cuda:0', size=(10, 0), dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 10)),
       values=tensor([], size=(10, 0)),
       device='cuda:0', size=(0,), nnz=10, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(0, 10)),
       values=tensor([], size=(10, 0)),
       device='cuda:0', size=(0,), nnz=10, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(0, 20)),
       values=tensor([], size=(20, 0)),
       device='cuda:0', size=(0,), nnz=20, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(0, 10), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([], device='cuda:0', size=(10, 0), dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([2])
# nnz: 3
# sparseDim: 0
# indices shape: torch.Size([0, 3])
# values shape: torch.Size([3, 2])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([[0, 0],
                      [0, 1],
                      [1, 1]]),
       device='cuda:0', size=(2,), nnz=3, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64)
# _values
tensor([[0, 0],
        [0, 1],
        [1, 1]], device='cuda:0', dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([[0.0000, 0.3333],
                      [0.6667, 1.0000],
                      [1.3333, 1.6667]]),
       device='cuda:0', size=(2,), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([[0.0000, 0.3333],
                      [0.6667, 1.0000],
                      [1.3333, 1.6667]]),
       device='cuda:0', size=(2,), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(0, 6)),
       values=tensor([[0.0000, 0.3333],
                      [0.6667, 1.0000],
                      [1.3333, 1.6667],
                      [0.0000, 0.3333],
                      [0.6667, 1.0000],
                      [1.3333, 1.6667]]),
       device='cuda:0', size=(2,), nnz=6, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([[0.0000, 0.3333],
        [0.6667, 1.0000],
        [1.3333, 1.6667]], device='cuda:0', dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([100, 3])
# nnz: 3
# sparseDim: 1
# indices shape: torch.Size([1, 3])
# values shape: torch.Size([3, 3])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([[0, 1, 2]]),
       values=tensor([[0, 0, 0],
                      [0, 0, 1],
                      [1, 1, 1]]),
       device='cuda:0', size=(100, 3), nnz=3, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([[0, 1, 2]], device='cuda:0')
# _values
tensor([[0, 0, 0],
        [0, 0, 1],
        [1, 1, 1]], device='cuda:0', dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([[0, 1, 2]]),
       values=tensor([[0.0000, 0.2222, 0.4444],
                      [0.6667, 0.8889, 1.1111],
                      [1.3333, 1.5556, 1.7778]]),
       device='cuda:0', size=(100, 3), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([[0, 1, 2]]),
       values=tensor([[0.0000, 0.2222, 0.4444],
                      [0.6667, 0.8889, 1.1111],
                      [1.3333, 1.5556, 1.7778]]),
       device='cuda:0', size=(100, 3), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([[0, 1, 2, 0, 1, 2]]),
       values=tensor([[0.0000, 0.2222, 0.4444],
                      [0.6667, 0.8889, 1.1111],
                      [1.3333, 1.5556, 1.7778],
                      [0.0000, 0.2222, 0.4444],
                      [0.6667, 0.8889, 1.1111],
                      [1.3333, 1.5556, 1.7778]]),
       device='cuda:0', size=(100, 3), nnz=6, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([[0, 1, 2]], device='cuda:0', grad_fn=<NotImplemented>)
# _values
tensor([[0.0000, 0.2222, 0.4444],
        [0.6667, 0.8889, 1.1111],
        [1.3333, 1.5556, 1.7778]], device='cuda:0', dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([100, 20, 3])
# nnz: 0
# sparseDim: 2
# indices shape: torch.Size([2, 0])
# values shape: torch.Size([0, 3])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0, 3)),
       device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(2, 0), dtype=torch.int64)
# _values
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0, 3)),
       device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0, 3)),
       device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0, 3)),
       device='cuda:0', size=(100, 20, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(2, 0), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([], device='cuda:0', size=(0, 3), dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([10, 0, 3])
# nnz: 3
# sparseDim: 0
# indices shape: torch.Size([0, 3])
# values shape: torch.Size([3, 10, 0, 3])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([], size=(3, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64)
# _values
tensor([], device='cuda:0', size=(3, 10, 0, 3), dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([], size=(3, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(0, 3)),
       values=tensor([], size=(3, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=3, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(0, 6)),
       values=tensor([], size=(6, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=6, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(0, 3), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([], device='cuda:0', size=(3, 10, 0, 3), dtype=torch.float32,
       grad_fn=<NotImplemented>)

# shape: torch.Size([10, 0, 3])
# nnz: 0
# sparseDim: 0
# indices shape: torch.Size([0, 0])
# values shape: torch.Size([0, 10, 0, 3])
########## torch.int32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 0)),
       values=tensor([], size=(0, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.int32,
       layout=torch.sparse_coo)
# _indices
tensor([], device='cuda:0', size=(0, 0), dtype=torch.int64)
# _values
tensor([], device='cuda:0', size=(0, 10, 0, 3), dtype=torch.int32)
########## torch.float32 ##########
# sparse tensor
tensor(indices=tensor([], size=(0, 0)),
       values=tensor([], size=(0, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo)
# after requires_grad_
tensor(indices=tensor([], size=(0, 0)),
       values=tensor([], size=(0, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo, requires_grad=True)
# after addition
tensor(indices=tensor([], size=(0, 0)),
       values=tensor([], size=(0, 10, 0, 3)),
       device='cuda:0', size=(10, 0, 3), nnz=0, dtype=torch.float32,
       layout=torch.sparse_coo, grad_fn=<AddBackward0>)
# _indices
tensor([], device='cuda:0', size=(0, 0), dtype=torch.int64,
       grad_fn=<NotImplemented>)
# _values
tensor([], device='cuda:0', size=(0, 10, 0, 3), dtype=torch.float32,
       grad_fn=<NotImplemented>)