Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Nov 28, 2023
1 parent fdaf800 commit 3d76159
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 31 deletions.
3 changes: 2 additions & 1 deletion torch_geometric/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
self.num_parts = num_parts
self.root = root
self.recursive = recursive

@property
def is_hetero(self) -> bool:
return isinstance(self.data, HeteroData)
Expand All @@ -102,7 +103,7 @@ def generate_partition(self):
recursive=self.recursive,
log=True,
keep_inter_cluster_edges=True,
format='csc'
sparse_format='csc',
)

node_perm = cluster_data.partition.node_perm
Expand Down
50 changes: 20 additions & 30 deletions torch_geometric/loader/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ClusterData(torch.utils.data.Dataset):
keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,
will keep inter-cluster edge connections. (default: :obj:`False`)
"""

def __init__(
self,
data,
Expand All @@ -68,8 +67,7 @@ def __init__(
self.sparse_format = sparse_format
else:
raise NotImplementedError(
f"Supported formats: ['csr', 'csc'], got {sparse_format}"
)
f"Supported formats: ['csr', 'csc'], got {sparse_format}")

recursive_str = '_recursive' if recursive else ''
filename = f'metis_{num_parts}{recursive_str}.pt'
Expand Down Expand Up @@ -102,9 +100,8 @@ def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor:
index = col
else:
# Calculate CSR representation:
row, col = sort_edge_index(
edge_index, num_nodes=num_nodes, sort_by_row=False
)
row, col = sort_edge_index(edge_index, num_nodes=num_nodes,
sort_by_row=False)
colptr = index2ptr(col, size=num_nodes)
indptr = colptr
index = row
Expand Down Expand Up @@ -133,10 +130,8 @@ def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor:
).to(edge_index.device)

if cluster is None:
raise ImportError(
f"'{self.__class__.__name__}' requires either "
f"'pyg-lib' or 'torch-sparse'"
)
raise ImportError(f"'{self.__class__.__name__}' requires either "
f"'pyg-lib' or 'torch-sparse'")

return cluster

Expand All @@ -151,7 +146,8 @@ def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition:
# Permute `edge_index` based on node permutation:
edge_perm = torch.arange(edge_index.size(1), device=edge_index.device)
arange = torch.empty_like(node_perm)
arange[node_perm] = torch.arange(cluster.numel(), device=cluster.device)
arange[node_perm] = torch.arange(cluster.numel(),
device=cluster.device)
edge_index = arange[edge_index]

# Compute final CSR representation:
Expand All @@ -168,9 +164,8 @@ def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition:
indptr = index2ptr(col, size=cluster.numel())
index = row

return Partition(
indptr, index, partptr, node_perm, edge_perm, self.sparse_format
)
return Partition(indptr, index, partptr, node_perm, edge_perm,
self.sparse_format)

def _permute_data(self, data: Data, partition: Partition) -> Data:
# Permute node-level and edge-level attributes according to the
Expand All @@ -197,25 +192,21 @@ def __getitem__(self, idx: int) -> Data:
node_end = int(self.partition.partptr[idx + 1])
node_length = node_end - node_start

indptr = self.partition.indptr[node_start:node_end + 1]
edge_start = int(indptr[0])
edge_end = int(indptr[-1])
edge_length = edge_end - edge_start
indptr = indptr - edge_start

if self.sparse_format == 'csr':
rowptr = self.partition.indptr[node_start : node_end + 1]
edge_start = int(rowptr[0])
edge_end = int(rowptr[-1])
edge_length = edge_end - edge_start
rowptr = rowptr - edge_start
row = ptr2index(rowptr)
row = ptr2index(indptr)
col = self.partition.index[edge_start:edge_end]
if not self.keep_inter_cluster_edges:
edge_mask = (col >= node_start) & (col < node_end)
row = row[edge_mask]
col = col[edge_mask] - node_start
else:
colptr = self.partition.indptr[node_start : node_end + 1]
edge_start = int(colptr[0])
edge_end = int(colptr[-1])
edge_length = edge_end - edge_start
colptr = colptr - edge_start
col = ptr2index(colptr)
col = ptr2index(indptr)
row = self.partition.index[edge_start:edge_end]
if not self.keep_inter_cluster_edges:
edge_mask = (row >= node_start) & (row < node_end)
Expand Down Expand Up @@ -269,7 +260,6 @@ class ClusterLoader(torch.utils.data.DataLoader):
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
"""

def __init__(self, cluster_data, **kwargs):
self.cluster_data = cluster_data
iterator = range(len(cluster_data))
Expand All @@ -296,15 +286,15 @@ def _collate(self, batch: List[int]) -> Data:
rows, cols, nodes, cumsum = [], [], [], 0
for i in range(batch.numel()):
nodes.append(torch.arange(node_start[i], node_end[i]))
indptr = global_ptr[node_start[i] : node_end[i] + 1]
indptr = global_ptr[node_start[i]:node_end[i] + 1]
indptr = indptr - edge_start[i]
if self.cluster_data.partition.sparse_format == 'csr':
row = ptr2index(indptr) + cumsum
col = global_index[edge_start[i] : edge_end[i]]
col = global_index[edge_start[i]:edge_end[i]]

else:
col = ptr2index(indptr) + cumsum
row = global_index[edge_start[i] : edge_end[i]]
row = global_index[edge_start[i]:edge_end[i]]

rows.append(row)
cols.append(col)
Expand Down

0 comments on commit 3d76159

Please sign in to comment.