Hi,
I'm trying to get the edge attribute for each step of a random walk. The use case is retaining edge types in a walk on a heterogeneous graph that I've converted to a homogeneous one. Using the idea in rusty1s/pytorch_sparse#214, I have this so far:
import torch
from torch_geometric.datasets import DBLP
from torch_cluster import random_walk

def get_edge_attr(edge_index: torch.Tensor, edge_attr: torch.Tensor, query_row: torch.Tensor, query_col: torch.Tensor) -> torch.Tensor:
    row, col = edge_index
    row_mask = row == query_row.view(-1, 1)
    col_mask = col == query_col.view(-1, 1)
    mask = torch.max(torch.logical_and(row_mask, col_mask), dim=0).values
    return edge_attr[mask]

data = DBLP()[0].to_homogeneous()
num_walks = 1
walk_length = 5
start = torch.arange(data.num_nodes).view(-1, 1).repeat(1, num_walks).view(-1)
rw = random_walk(data.edge_index[0], data.edge_index[1], start, walk_length, num_nodes=data.num_nodes)
print(rw)
>>> tensor([[ 0, 10514, 19555, 10973, 21385, 9952],
... [ 1, 6520, 20381, 6520, 20381, 6520]])
l, r = rw[:2].unfold(1, 2, 1).flatten(0, 1).t()  # sliding window of size 2 over each walk, flattened to [2, num_steps]
print(get_edge_attr(data.edge_index, data.edge_type, l, r))
>>> tensor([1, 1, 4, 4, 4, 2, 2, 2])
I know this can be made more efficient with searchsorted, as mentioned in the linked issue, but is this correct in general? Is there a built-in way to do this that I've missed, or is this my best bet?
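For reference, here is a rough sketch of what I think the searchsorted variant would look like. The helper name `edge_attr_per_step` is made up for illustration and is not part of any library. It assumes `edge_attr` is 1-D (like `data.edge_type`), that `edge_index` is non-empty with no duplicate (row, col) pairs, and that `num_nodes * num_nodes` fits in int64. Unlike the mask-based version above, which returns values in edge order and only once per distinct edge (hence 8 values for 10 steps), this one is meant to return one value per walk step, in walk order:

```python
import torch


def edge_attr_per_step(edge_index: torch.Tensor, edge_attr: torch.Tensor,
                       walks: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: one edge attribute per step of each walk.

    Assumes a 1-D integer `edge_attr` and unique (row, col) pairs.
    Steps whose (src, dst) pair is not an edge are marked with -1.
    """
    row, col = edge_index
    # Encode each (row, col) pair as one integer key and sort the keys.
    keys = row * num_nodes + col
    sorted_keys, perm = keys.sort()

    # Consecutive node pairs of every walk: shape [num_walks, walk_length].
    src, dst = walks[:, :-1], walks[:, 1:]
    query = (src * num_nodes + dst).reshape(-1)

    # Binary-search each query key; perm maps back to the original edge order.
    pos = torch.searchsorted(sorted_keys, query)
    pos = pos.clamp(max=sorted_keys.numel() - 1)
    found = sorted_keys[pos] == query

    out = edge_attr[perm[pos]]
    out[~found] = -1  # no matching edge for this step
    return out.view(src.shape)
```

With the setup above, the call would be something like `edge_attr_per_step(data.edge_index, data.edge_type, rw, data.num_nodes)`, giving a `[num_walks, walk_length]` tensor of edge types.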
Thanks in advance.