Skip to content

Commit

Permalink
Restore correct order in the output of HeteroLinear in case of `is_…
Browse files Browse the repository at this point in the history
…sorted=False` (#6198)

If "is_sorted=False", the output of the forward pass will be unsorted to
the original vertex order of the input.

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
toenshoff and rusty1s authored Dec 12, 2022
1 parent bc47556 commit 25abbb1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
### Changed
- Fixed a bug in the output order in `HeteroLinear` for un-sorted type vectors ([#6198](https://github.com/pyg-team/pytorch_geometric/pull/6198))
- Breaking Change: Move `ExplainerConfig` arguments to the `Explainer` class ([#6176](https://github.com/pyg-team/pytorch_geometric/pull/6176))
- Refactored `NeighborSampler` to be input-type agnostic ([#6173](https://github.com/pyg-team/pytorch_geometric/pull/6173))
- Infer correct CUDA device ID in `profileit` decorator ([#6164](https://github.com/pyg-team/pytorch_geometric/pull/6164))
Expand Down
27 changes: 22 additions & 5 deletions test/nn/dense/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,32 @@ def test_copy_linear(lazy):


def test_hetero_linear():
x = torch.randn((3, 16))
node_type = torch.tensor([0, 1, 2])
x = torch.randn(3, 16)
type_vec = torch.tensor([0, 1, 2])

lin = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
lin = HeteroLinear(16, 32, num_types=3)
assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)'

out = lin(x, node_type)
out = lin(x, type_vec)
assert out.size() == (3, 32)

if is_full_test():
jit = torch.jit.script(lin)
assert torch.allclose(jit(x, node_type), out)
assert torch.allclose(jit(x, type_vec), out)


@withPackage('pyg_lib')
@pytest.mark.parametrize('type_vec', [
torch.tensor([0, 0, 1, 1, 2, 2]),
torch.tensor([0, 1, 2, 0, 1, 2]),
])
def test_hetero_linear_sort(type_vec):
x = torch.randn(type_vec.numel(), 16)

lin = HeteroLinear(16, 32, num_types=3)
out = lin(x, type_vec)

for i in range(type_vec.numel()):
node_type = int(type_vec[i])
expected = x[i] @ lin.weight[node_type] + lin.bias[node_type]
assert torch.allclose(out[i], expected, atol=1e-6)
6 changes: 6 additions & 0 deletions torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
if torch_geometric.typing.WITH_PYG_LIB:
assert self.weight is not None

perm: Optional[Tensor] = None
if not self.is_sorted:
if (type_vec[1:] < type_vec[:-1]).any():
type_vec, perm = type_vec.sort()
Expand All @@ -260,6 +261,11 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
if self.bias is not None:
out += self.bias[type_vec]

if perm is not None: # Restore original order (if necessary).
out_unsorted = torch.empty_like(out)
out_unsorted[perm] = out
out = out_unsorted
else:
assert self.lins is not None
out = x.new_empty(x.size(0), self.out_channels)
Expand Down

0 comments on commit 25abbb1

Please sign in to comment.