Skip to content

Commit

Permalink
Add Index.index_select() (#9286)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored May 4, 2024
1 parent 6d4c63f commit c0e1459
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
Expand Down
21 changes: 21 additions & 0 deletions test/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,24 @@ def test_flip(dtype, device):
assert out.equal(tensor([2, 1, 1, 0], device=device))
assert out.dim_size == 3
assert not out.is_sorted


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_index_select(dtype, device):
kwargs = dict(dtype=dtype, device=device)
index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)

i = tensor([1, 3], device=device)
out = index.index_select(0, i)
assert out.equal(tensor([1, 2], device=device))
assert isinstance(out, Index)
assert out.dim_size == 3
assert not out.is_sorted

inplace = torch.empty(2, dtype=dtype, device=device)
out = torch.index_select(index, 0, i, out=inplace)
assert out.equal(tensor([1, 2], device=device))
assert out.data_ptr() == inplace.data_ptr()
assert not isinstance(out, Index)
assert not isinstance(inplace, Index)
15 changes: 15 additions & 0 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,18 @@ def _flip(
out._dim_size = input.dim_size

return out


@implements(aten.index_select.default)
def _index_select(
input: Index,
dim: int,
index: Tensor,
) -> Index:

data = aten.index_select.default(input._data, dim, index)

out = Index(data)
out._dim_size = input.dim_size

return out

0 comments on commit c0e1459

Please sign in to comment.