
Get edge attributes for random walk #137

Closed
@jacobdanovitch


Hi,

I'm trying to get the edge attribute for each step of a random walk. My use case is keeping track of edge types while walking over a heterogeneous graph that I've converted to a homogeneous one. Based on the idea in rusty1s/pytorch_sparse#214, this is what I have so far:

import torch
from torch_geometric.datasets import DBLP
from torch_cluster import random_walk

def get_edge_attr(edge_index: torch.Tensor, edge_attr: torch.Tensor, query_row: torch.Tensor, query_col: torch.Tensor) -> torch.Tensor:
    row, col = edge_index
    # (num_queries, num_edges) masks: entry [i, j] is True if edge j matches query i
    row_mask = row == query_row.view(-1, 1)
    col_mask = col == query_col.view(-1, 1)
    # keep every edge that matches at least one (query_row, query_col) pair
    mask = torch.logical_and(row_mask, col_mask).any(dim=0)
    return edge_attr[mask]

# DBLP needs a download directory; 'data/DBLP' is only an example path
data = DBLP(root='data/DBLP')[0].to_homogeneous()

num_walks = 1
walk_length = 5

# num_walks walks starting from every node
start = torch.arange(data.num_nodes).view(-1, 1).repeat(1, num_walks).view(-1)
rw = random_walk(data.edge_index[0], data.edge_index[1], start, walk_length, num_nodes=data.num_nodes)
print(rw[:2])  # first two walks (walks are random, so output varies)
# tensor([[    0, 10514, 19555, 10973, 21385,  9952],
#         [    1,  6520, 20381,  6520, 20381,  6520]])

# sliding window of size 2 over each walk: (2, 5, 2) -> (10, 2) -> (2, 10)
l, r = rw[:2].unfold(1, 2, 1).flatten(0, 1).t()
print(get_edge_attr(data.edge_index, data.edge_type, l, r))
# tensor([1, 1, 4, 4, 4, 2, 2, 2])
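A small toy check may help with the "is this correct in general?" question. The graph, edge types, and walks below are made up for illustration. The boolean mask in `get_edge_attr` marks every edge that matches *any* queried pair, so `edge_attr[mask]` comes back in `edge_index` order with repeated steps collapsed into one value. In the toy case, 6 steps give only 3 values. The same effect would explain why the two DBLP walks above (10 steps) produce 8 values: the second walk keeps bouncing between 6520 and 20381, so its 5 steps touch only 3 distinct edges.

```python
import torch

# Made-up directed toy graph with 4 edges and arbitrary edge "types".
edge_index = torch.tensor([[0, 1, 2, 1],
                           [1, 2, 1, 0]])
edge_type = torch.tensor([10, 20, 30, 40])

# Two made-up walks with 3 steps each (4 nodes per walk).
rw = torch.tensor([[0, 1, 2, 1],
                   [1, 2, 1, 2]])

# Sliding window of size 2: (2, 3, 2) -> (6, 2) -> (2, 6), one pair per step.
src, dst = rw.unfold(1, 2, 1).reshape(-1, 2).t()
print(src.tolist(), dst.tolist())  # [0, 1, 2, 1, 2, 1] [1, 2, 1, 2, 1, 2]

# Same boolean-mask idea as get_edge_attr in the question.
row, col = edge_index
mask = ((row == src.unsqueeze(1)) & (col == dst.unsqueeze(1))).any(dim=0)
print(edge_type[mask].tolist())  # [10, 20, 30]: 3 values for 6 steps
```

A per-step result for these walks would instead be `[10, 20, 30, 20, 30, 20]`.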

I know this can be made more efficient with searchsorted as mentioned in the linked issue, but is this correct in general? Is there a built-in way to do this that I've missed or is this my best bet?
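Here is a minimal sketch of one way to do a `searchsorted`-based lookup that returns one attribute per step, in walk order. It assumes every `(row, col)` pair occurs at most once in `edge_index`; with parallel edges, it would return only one of them. The helper name `get_edge_attr_per_step` and the toy data are hypothetical, and this is not necessarily the approach from the linked issue. Each edge is encoded as the integer key `row * num_nodes + col`. The keys are sorted once, and each step's key is binary-searched with `torch.searchsorted`. That costs O(E + Q) memory instead of the (Q, E) masks built by `get_edge_attr`. If a step is not an edge in `edge_index`, the sketch raises an error rather than silently dropping it.

```python
import torch


def get_edge_attr_per_step(edge_index, edge_attr, src, dst, num_nodes):
    """Hypothetical helper: edge_attr for each (src[i], dst[i]) pair, in query order.

    Assumes every (row, col) pair occurs at most once in edge_index.
    """
    row, col = edge_index
    # Encode each edge as a single integer key and sort the keys once.
    sorted_key, perm = torch.sort(row * num_nodes + col)

    # Binary-search the key of every walk step.
    query = src * num_nodes + dst
    pos = torch.searchsorted(sorted_key, query)
    # Clamp so positions past the end can be indexed, then check for exact hits.
    pos = pos.clamp(max=sorted_key.numel() - 1)
    if not bool((sorted_key[pos] == query).all()):
        raise ValueError("some (src, dst) steps are not edges in edge_index")
    # perm maps sorted positions back to original edge ids.
    return edge_attr[perm[pos]]


# Made-up toy graph and walks as a quick usage example.
edge_index = torch.tensor([[0, 1, 2, 1],
                           [1, 2, 1, 0]])
edge_type = torch.tensor([10, 20, 30, 40])
rw = torch.tensor([[0, 1, 2, 1],
                   [1, 2, 1, 2]])

src, dst = rw.unfold(1, 2, 1).reshape(-1, 2).t()
step_type = get_edge_attr_per_step(edge_index, edge_type, src, dst, num_nodes=3)
# One value per step; reshape back to (num_walks, steps_per_walk).
print(step_type.view(rw.size(0), -1).tolist())  # [[10, 20, 30], [20, 30, 20]]
```

For the DBLP example, the call would presumably be `get_edge_attr_per_step(data.edge_index, data.edge_type, l, r, data.num_nodes)`.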

Thanks in advance.
