torch_geometric.nn.nearest argument order flipped #9763
Description
🐛 Describe the bug
The argument order in `nearest` is flipped.
The documentation states "Finds for each element in `y` the `k` nearest point in `x`", which means the output dimension should be the same as that of `y`. However, when I tested the function with the following code, the output dimension is the same as that of `x`.
```python
import torch
from torch_geometric.nn import nearest

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
cluster = nearest(x, y)
print(cluster, cluster.shape)
```

Output:

```
tensor([0, 0, 1, 1]) torch.Size([4])
```
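To make the mismatch concrete, the sketch below computes both readings by brute force in plain PyTorch. It is not the PyG or torch_cluster implementation, and the helper names `nearest_per_y` and `nearest_per_x` are hypothetical. `nearest_per_y` follows the quoted docstring (one index into `x` per element of `y`), while `nearest_per_x` assigns each element of `x` to its closest element of `y`, which reproduces the output reported above.

```python
import torch


def nearest_per_y(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: for each row of y, the index of the closest row of x.
    # Pairwise squared Euclidean distances, shape [len(y), len(x)].
    d = ((y[:, None, :] - x[None, :, :]) ** 2).sum(dim=-1)
    # One entry per element of y.
    return d.argmin(dim=1)


def nearest_per_x(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: for each row of x, the index of the closest row of y.
    # Same computation with the roles of x and y swapped; one entry per element of x.
    return nearest_per_y(y, x)


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

print(nearest_per_y(x, y).shape)  # torch.Size([2]): the documented reading
print(nearest_per_x(x, y))        # tensor([0, 0, 1, 1]): matches the reported output
```

Squared distances are used because taking the square root does not change which point is closest. Each row of `y` here is equidistant from two points of `x`, so only the shape of `nearest_per_y` is shown; which of the tied indices comes back is left to `torch.argmin`.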
The issue went away after I swapped `x` and `y`.
I think this is because `torch_geometric.nn.nearest` directly adopted the argument-order convention from `torch_cluster.nearest()`, which is flipped relative to `torch_cluster.knn()`.
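The ordering difference between the two torch_cluster functions can be illustrated with brute-force stand-ins. `knn_ref` and `nearest_ref` are hypothetical helpers written for this sketch (no batch support), assuming the conventions described in the torch_cluster docs: `knn(x, y, k)` returns a `[2, len(y) * k]` index tensor whose first row indexes `y` and second row indexes `x`, while `nearest(x, y)` returns one index into `y` per element of `x`. Under those assumptions, `nearest(x, y)` lines up with the second row of `knn(y, x, k=1)`, i.e. the query and reference arguments trade places between the two functions.

```python
import torch


def knn_ref(x: torch.Tensor, y: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical brute-force stand-in for torch_cluster.knn (no batch support):
    # for each element of y, the k nearest points in x.
    d = ((y[:, None, :] - x[None, :, :]) ** 2).sum(dim=-1)  # [len(y), len(x)]
    col = d.topk(k, dim=1, largest=False).indices           # indices into x
    row = torch.arange(y.size(0)).repeat_interleave(k)      # indices into y
    return torch.stack([row, col.reshape(-1)])              # [2, len(y) * k]


def nearest_ref(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical brute-force stand-in for torch_cluster.nearest (no batch support):
    # for each element of x, the index of the closest point in y.
    d = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)  # [len(x), len(y)]
    return d.argmin(dim=1)


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])

# nearest(x, y) matches knn with the arguments swapped and k=1.
print(nearest_ref(x, y))      # tensor([0, 0, 1, 1])
print(knn_ref(y, x, k=1)[1])  # tensor([0, 0, 1, 1])
```

`topk` with `largest=False` is used so the same helper covers `k > 1`; for `k = 1` and no ties it picks the same points as `argmin`.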
Versions
[pip3] numpy==2.0.2
[pip3] torch==2.4.1
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torchaudio==2.4.1
[pip3] torchvision==0.19.1