Skip to content

Commit

Permalink
[Transforms] Fix sparse-sparse matrix multiplication support on Windo…
Browse files Browse the repository at this point in the history
…ws (#8197)
  • Loading branch information
rusty1s authored Oct 16, 2023
1 parent 6be5a57 commit 8c0f635
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 9 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.5.0] - 2023-MM-DD

### Added

- Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032))

### Changed

### Deprecated

### Removed

### Fixed

- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197))

### Removed

## [2.4.0] - 2023-10-12

### Added
Expand Down
2 changes: 0 additions & 2 deletions test/transforms/test_add_positional_encoding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch

from torch_geometric.data import Data
from torch_geometric.testing import onlyLinux
from torch_geometric.transforms import (
AddLaplacianEigenvectorPE,
AddRandomWalkPE,
Expand Down Expand Up @@ -74,7 +73,6 @@ def test_eigenvector_permutation_invariance():
assert torch.allclose(out1.x[perm].abs(), out2.x.abs(), atol=1e-6)


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_add_random_walk_pe():
x = torch.randn(6, 4)
edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
Expand Down
2 changes: 0 additions & 2 deletions test/transforms/test_two_hop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import torch

from torch_geometric.data import Data
from torch_geometric.testing import onlyLinux
from torch_geometric.transforms import TwoHop


@onlyLinux # TODO (matthias) Investigate CSR @ CSR support on Windows.
def test_two_hop():
transform = TwoHop()
assert str(transform) == 'TwoHop()'
Expand Down
7 changes: 6 additions & 1 deletion torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
Expand All @@ -12,6 +13,7 @@
scatter,
to_edge_index,
to_scipy_sparse_matrix,
to_torch_coo_tensor,
to_torch_csr_tensor,
)

Expand Down Expand Up @@ -136,7 +138,10 @@ def forward(self, data: Data) -> Data:
value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]
value = 1.0 / value

adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())
if torch_geometric.typing.WITH_WINDOWS:
adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())
else:
adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())

out = adj
pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
Expand Down
12 changes: 10 additions & 2 deletions torch_geometric/transforms/two_hop.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import torch

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import (
coalesce,
remove_self_loops,
to_edge_index,
to_torch_coo_tensor,
to_torch_csr_tensor,
)

Expand All @@ -19,8 +21,14 @@ def forward(self, data: Data) -> Data:
edge_index, edge_attr = data.edge_index, data.edge_attr
N = data.num_nodes

adj = to_torch_csr_tensor(edge_index, size=(N, N))
edge_index2, _ = to_edge_index(adj @ adj)
if torch_geometric.typing.WITH_WINDOWS:
adj = to_torch_coo_tensor(edge_index, size=(N, N))
else:
adj = to_torch_csr_tensor(edge_index, size=(N, N))

adj = adj @ adj

edge_index2, _ = to_edge_index(adj)
edge_index2, _ = remove_self_loops(edge_index2)

edge_index = torch.cat([edge_index, edge_index2], dim=1)
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os
import platform
import sys
import warnings
Expand All @@ -14,6 +15,7 @@
WITH_PT112 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 12
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13

WITH_WINDOWS = os.name == 'nt'
WITH_ARM = platform.machine() != 'x86_64'

if not hasattr(torch, 'sparse_csc'):
Expand Down

0 comments on commit 8c0f635

Please sign in to comment.