Skip to content

Allow returning edge indices from random walk #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion test/test_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@pytest.mark.parametrize('device', devices)
def test_rw(device):
def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
Expand All @@ -21,10 +21,59 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item()


@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4

out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]


@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10

node_seq, edge_seq = random_walk(
row, col, start, walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()

for n in range(start.size(0)):
cur = start[n].item()
for i in range(1, walk_length):
assert node_seq[n, i].item() in col[row == cur].tolist()
cur = node_seq[n, i].item()

assert (edge_seq != -1).all()


@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4

node_seq, edge_seq = random_walk(
row, col, start, walk_length,
num_nodes=3,
return_edge_indices=True,
)
assert node_seq.tolist() == [
[0, 1, 0, 1, 0],
[1, 0, 1, 0, 1],
[2, 2, 2, 2, 2],
]
assert edge_seq.tolist() == [
[0, 1, 0, 1],
[1, 0, 1, 0],
[-1, -1, -1, -1],
]
29 changes: 23 additions & 6 deletions torch_cluster/rw.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from typing import Optional
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


@torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
p: float = 1, q: float = 1, coalesced: bool = True,
num_nodes: Optional[int] = None) -> Tensor:
def random_walk(
row: Tensor,
col: Tensor,
start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
Expand All @@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)

:rtype: :class:`LongTensor`
"""
Expand All @@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:])

return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q)[0]
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q,
)

if return_edge_indices:
return node_seq, edge_seq

return node_seq