Skip to content

Commit

Permalink
Faster topk implementation for large enough graphs (#6123)
Browse files Browse the repository at this point in the history
Topk needs quite some time to construct the mask in my application.
I noticed I can avoid the loop, if all of the graphs in the batch are
large than `ratio`.

Co-authored-by: mova <mova@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people authored Dec 3, 2022
1 parent 11c8cbd commit adf3bad
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
### Changed
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113))
- Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123/))
### Removed

## [2.2.0] - 2022-12-01
Expand Down
16 changes: 10 additions & 6 deletions test/nn/pool/test_topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,25 @@ def test_topk():
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm1 = topk(x, 0.5, batch)

assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm2 = topk(x, 3, batch)
perm2 = topk(x, 2, batch)
assert perm2.tolist() == [1, 0, 5, 3]
assert x[perm2].tolist() == [4, 2, 9, 6]
assert batch[perm2].tolist() == [0, 0, 1, 1]

assert perm2.tolist() == [1, 0, 5, 3, 2]
assert x[perm2].tolist() == [4, 2, 9, 6, 5]
assert batch[perm2].tolist() == [0, 0, 1, 1, 1]
perm3 = topk(x, 3, batch)
assert perm3.tolist() == [1, 0, 5, 3, 2]
assert x[perm3].tolist() == [4, 2, 9, 6, 5]
assert batch[perm3].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 3, batch), perm2)
assert torch.equal(jit(x, 2, batch), perm2)
assert torch.equal(jit(x, 3, batch), perm3)


def test_filter_adj():
Expand Down
19 changes: 13 additions & 6 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,20 @@ def topk(
else:
k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)
]
mask = torch.cat(mask, dim=0)
if isinstance(ratio, int) and (k == ratio).all():
# If all graphs have exactly `ratio` or more than `ratio` entries,
# we can just pick the first entries in `perm` batch-wise:
index = torch.arange(batch_size, device=x.device) * max_num_nodes
index = index.view(-1, 1).repeat(1, ratio).view(-1)
index += torch.arange(ratio, device=x.device).repeat(batch_size)
else:
# Otherwise, compute indices per graph:
index = torch.cat([
torch.arange(k[i], device=x.device) + i * max_num_nodes
for i in range(batch_size)
], dim=0)

perm = perm[mask]
perm = perm[index]

else:
raise ValueError("At least one of 'min_score' and 'ratio' parameters "
Expand Down

0 comments on commit adf3bad

Please sign in to comment.