Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding unbatching support for torch.sparse.Tensor #7037

Merged
merged 6 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
rusty1s committed Mar 27, 2023
commit 56b8c5939f172a4a1ac0c8096a3f70d0b19f9e1b
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.4.0] - 2023-MM-DD

### Added

- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))
- Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))

### Changed
Expand Down
42 changes: 42 additions & 0 deletions test/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,48 @@ def test_batch_with_sparse_tensor():
assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2]


def test_batch_with_torch_coo_tensor():
x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo()
data1 = Data(x=x)

x = torch.tensor([[1.0], [2.0]]).to_sparse_coo()
data2 = Data(x=x)

x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo()
data3 = Data(x=x)

batch = Batch.from_data_list([data1])
assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])')
assert batch.num_graphs == len(batch) == 1
assert batch.x.to_dense().tolist() == [[1], [2], [3]]
assert batch.batch.tolist() == [0, 0, 0]
assert batch.ptr.tolist() == [0, 3]

batch = Batch.from_data_list([data1, data2, data3])

assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])')
assert batch.num_graphs == len(batch) == 3
assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]
assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]
assert batch.ptr.tolist() == [0, 3, 5, 9]

assert str(batch[0]) == ("Data(x=[3, 1])")
assert str(batch[1]) == ("Data(x=[2, 1])")
assert str(batch[2]) == ("Data(x=[4, 1])")

data_list = batch.to_data_list()
assert len(data_list) == 3

assert len(data_list[0]) == 1
assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]]

assert len(data_list[1]) == 1
assert data_list[1].x.to_dense().tolist() == [[1], [2]]

assert len(data_list[2]) == 1
assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]]


def test_batching_with_new_dimension():
torch_geometric.set_debug(True)

Expand Down