Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-factor ClusterLoader + Integrate pyg-lib metis computation #7416

Merged
merged 11 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
rusty1s committed May 24, 2023
commit 81507de327c56fcca6e3a76fd8e5e28b541f1bf9
23 changes: 23 additions & 0 deletions test/utils/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch

from torch_geometric.utils.map import map_index


def test_map_index():
    """Check `map_index` when every value of `src` occurs in `index`."""
    src = torch.tensor([2, 0, 1, 0, 3])
    index = torch.tensor([3, 2, 0, 1])

    out, mask = map_index(src, index)
    # Each entry is replaced by its position inside `index`
    # (3 -> 0, 2 -> 1, 0 -> 2, 1 -> 3), and nothing is filtered out.
    assert out.tolist() == [1, 2, 3, 2, 0]
    assert mask.tolist() == [True, True, True, True, True]


def test_map_index_na():
    """Check `map_index` when one value of `src` is missing from `index`."""
    # The value `1` (third entry of the source) has no counterpart below.
    source = torch.tensor([2, 0, 1, 0, 3])
    lookup = torch.tensor([3, 2, 0])

    mapped, found = map_index(source, lookup)

    # The unmatched entry is dropped from the output and flagged in the mask.
    expected_mask = [True, True, False, True, True]
    expected_out = [1, 2, 2, 0]
    assert mapped.tolist() == expected_out
    assert found.tolist() == expected_mask
12 changes: 7 additions & 5 deletions torch_geometric/loader/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch_geometric.data import Data
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort, narrow, select, sort_edge_index
from torch_geometric.utils.map import map_index
from torch_geometric.utils.sparse import index2ptr, ptr2index


Expand Down Expand Up @@ -241,24 +242,25 @@ def _collate(self, batch: List[int]) -> Data:
# connectivity. This is done by slicing the corresponding source and
# destination indices for each partition and adjusting their indices to
# start from zero:
rows, cols, cumsum = [], [], 0
rows, cols, nodes, cumsum = [], [], [], 0
for i in range(batch.numel()):
nodes.append(torch.arange(node_start[i], node_end[i]))
rowptr = global_rowptr[node_start[i]:node_end[i] + 1]
rowptr = rowptr - edge_start[i]
row = ptr2index(rowptr) + cumsum
col = global_col[edge_start[i]:edge_end[i]]
col = col - node_start[i] + cumsum
rows.append(row)
cols.append(col)
cumsum += rowptr.numel() - 1

node = torch.cat(nodes, dim=0)
row = torch.cat(rows, dim=0)
col = torch.cat(cols, dim=0)

# Mask out any edge that does not connect nodes within the same batch:
edge_mask = (col >= 0) & (col < cumsum)
# Map `col` vector to valid entries and remove any entries that do not
# connect two nodes within the same mini-batch:
col, edge_mask = map_index(col, node)
row = row[edge_mask]
col = col[edge_mask]

out = copy.copy(self.cluster_data.data)

Expand Down
50 changes: 50 additions & 0 deletions torch_geometric/utils/map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Tuple

import torch
from torch import Tensor


def map_index(src: Tensor, index: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Maps indices in :obj:`src` to the positional value of their
    corresponding occurrence in :obj:`index`.

    Entries of :obj:`src` that do not occur in :obj:`index` are removed from
    the returned tensor and are marked as :obj:`False` in the returned mask.
    Both outputs are placed on the device of :obj:`index`.

    Args:
        src (torch.Tensor): The one-dimensional, non-floating-point source
            tensor to map.
        index (torch.Tensor): The one-dimensional, non-floating-point index
            tensor that denotes the new mapping.

    :rtype: (:class:`torch.Tensor`, :class:`torch.BoolTensor`)

    Examples:

        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0, 1])

        >>> map_index(src, index)
        (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True]))

        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0])

        >>> map_index(src, index)
        (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True]))
    """
    import pandas as pd

    assert src.dim() == 1 and index.dim() == 1
    assert not src.is_floating_point()
    assert not index.is_floating_point()

    if src.numel() == 0:
        # Nothing to map: return empty results without running the merge.
        out = torch.empty(0, dtype=index.dtype, device=index.device)
        mask = torch.empty(0, dtype=torch.bool, device=index.device)
        return out, mask

    # Hand pandas plain NumPy arrays on both sides of the join instead of
    # relying on implicit tensor-to-array coercion.
    index_np = index.detach().cpu().numpy()
    src_np = src.detach().cpu().numpy()

    # Left frame: keyed by the values of `index`, holding their positions.
    positions = pd.RangeIndex(0, index.size(0))
    df = pd.DataFrame(index=index_np, data={'out': positions})
    ser = pd.Series(src_np, name='key')

    # A right join keeps the rows of `src`; values without a match in `index`
    # receive NaN in the `out` column.
    result = df.merge(ser, how='right', left_index=True, right_on='key')
    out = torch.from_numpy(result['out'].values).to(index.device)

    if out.is_floating_point():
        # The column was up-cast to float, so NaN may mark unmatched entries:
        # filter them out and restore the integer dtype.
        mask = torch.isnan(out).logical_not_()
        out = out[mask].to(index.dtype)
        return out, mask

    # Integer column: every entry of `src` was found in `index`.
    out = out.to(index.dtype)
    mask = torch.ones_like(out, dtype=torch.bool)
    return out, mask