
torch_geometric.nn.nearest argument order flipped #9763

Open
@gegewen

Description

🐛 Describe the bug

The argument order of nearest is flipped relative to its documentation.

The documentation states "Finds for each element in y the k nearest point in x", which means the output should have one entry per element of y, i.e. the same length as y. However, when I tested the function with the code below, the output has the same length as x.

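```python
import torch
from torch_geometric.nn import nearest

# Four points in x, two points in y.
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

# Per the docs, this should return one index per element of y (shape [2]).
cluster = nearest(x, y)
print(cluster, cluster.shape)
```

Output:

```
tensor([0, 0, 1, 1]) torch.Size([4])
```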

Calling the function with x and y swapped, i.e. nearest(y, x), gives one entry per element of y, as the documentation describes.
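To make the two call orders concrete, here is a brute-force sketch in plain Python (no torch) of the behavior observed above: for each point in x, the index of its nearest point in y. nearest_reference is a hypothetical helper for illustration, not part of PyG or torch_cluster, and it breaks ties toward the lowest index, which may differ from how torch_cluster breaks them.

```python
import math


def nearest_reference(x, y):
    """Hypothetical brute-force helper: for each point in x, the index of its
    nearest point in y (the behavior observed for nearest(x, y)).

    Ties are broken toward the lowest index in y; this is an assumption of the
    sketch and may not match torch_cluster.
    """
    return [min(range(len(y)), key=lambda j: math.dist(p, y[j])) for p in x]


x = [[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]
y = [[-1.0, 0.0], [1.0, 0.0]]

print(nearest_reference(x, y))  # one entry per point in x: [0, 0, 1, 1]
print(nearest_reference(y, x))  # one entry per point in y: [0, 2]
```

The first call matches the reported output; the swapped call has one entry per point in y. Both y points sit exactly between two x points, so the values [0, 2] depend on the tie-breaking rule chosen here.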

I think this is because torch_geometric.nn.nearest directly adopts the argument convention of torch_cluster.nearest(), whose argument order is flipped relative to torch_cluster.knn().
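The convention mismatch can be checked with a brute-force sketch of knn in plain Python, assuming torch_cluster.knn(x, y, k) finds, for each point in y, its k nearest points in x and returns a (row, col) pair with row indexing y and col indexing x. knn_reference is a hypothetical helper for illustration only. With the issue's inputs, the knn convention yields one neighbor per point in y, while swapping the arguments reproduces the [0, 0, 1, 1] reported for nearest(x, y).

```python
import math


def knn_reference(x, y, k):
    """Hypothetical brute-force sketch of the assumed torch_cluster.knn convention:
    for each point in y, the indices of its k nearest points in x.

    Returns (row, col): row indexes y, col indexes x. Ties are broken toward
    the lowest index in x (an assumption of this sketch).
    """
    row, col = [], []
    for i, q in enumerate(y):
        dists = [math.dist(q, p) for p in x]
        # sorted() is stable, so lower indices come first among equal distances.
        order = sorted(range(len(x)), key=dists.__getitem__)
        for j in order[:k]:
            row.append(i)
            col.append(j)
    return row, col


x = [[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]
y = [[-1.0, 0.0], [1.0, 0.0]]

# knn convention with the issue's argument order: one neighbor per point in y.
print(knn_reference(x, y, k=1))  # ([0, 1], [0, 2])

# Swapping the arguments reproduces the output reported for nearest(x, y).
print(knn_reference(y, x, k=1))  # ([0, 1, 2, 3], [0, 0, 1, 1])
```

In other words, under these assumptions nearest(x, y) behaves like knn(y, x, k=1) rather than knn(x, y, k=1), which is the flip described above.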

Versions

[pip3] numpy==2.0.2
[pip3] torch==2.4.1
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torchaudio==2.4.1
[pip3] torchvision==0.19.1
