Description
I noticed a potential bug in the nearest() function when it is called with certain invalid inputs. When a batch index is present in the batch_x argument but not in the batch_y argument, nearest() produces non-deterministic outputs and sometimes raises the error "CUDA error: an illegal memory access was encountered".
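This situation can be detected before nearest() is called by comparing the sets of batch indices. Below is a minimal plain-Python sketch. The helper name missing_batches is hypothetical and is not part of torch_cluster. With tensors, the same check could be run on batch_x.unique().tolist() and batch_y.unique().tolist().

```python
def missing_batches(batch_x, batch_y):
    """Return the sorted batch indices that occur in batch_x but not in batch_y.

    Hypothetical helper: for any index returned here, nearest() has no
    candidate points in y, so its result for those points is not well-defined.
    """
    return sorted(set(batch_x) - set(batch_y))


# Batch vectors from the invalid calls in the example session.
batch_x = [0, 1, 2, 3, 4]
batch_y = [0, 1, 2, 3]
print(missing_batches(batch_x, batch_y))  # [4]
```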
As an example:
In [1]: import torch_cluster
In [2]: import torch
In [3]: pos = torch.tensor([[0,0,0.],[1,1,1,],[2,2,2],[3,3,3],[4,4,4]]).cuda()
In [4]: torch_cluster.nearest(x=pos, y=pos[[0,1,2,3]],
batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3]).cuda())
Out[4]: tensor([0, 1, 2, 3, 0], device='cuda:0') # invalid inputs: for the empty batch, 0 is output
In [5]: torch_cluster.nearest(x=pos, y=pos,
batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3,4]).cuda())
Out[5]: tensor([0, 1, 2, 3, 4], device='cuda:0') # (valid inputs, output as expected)
In [6]: torch_cluster.nearest(x=pos, y=pos[[0,1,2,3]],
batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3]).cuda())
Out[6]: tensor([0, 1, 2, 3, 4], device='cuda:0') # invalid inputs: same code as In[4], but for the empty batch, 4 is output
Here the In [4] and In [6] lines are identical, but they produce different results. In both calls, the point [4,4,4] has batch index 4 in batch_x but no nearest neighbor, because no point in batch_y has batch index 4. In In [4], the returned index for this point is 0. In [5] is a call with valid inputs. Then in In [6], the returned index for the same point is 4, which is exactly the value In [5] produced at that position. I wonder whether the reason for this could be a read of uninitialized memory.
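The observed pattern fits a simple model of the suspected mechanism. Suppose the output buffer is allocated without initialization, as torch.empty does, and PyTorch's CUDA caching allocator may hand the same freed block back on the next call. Suppose also that results are written only for points whose batch has candidates in y. Then the entries for the empty batch keep stale values. The plain-Python sketch below is a hypothetical model, not torch_cluster's actual kernel. The function name nearest_unchecked and the zero-filled starting buffer are assumptions.

```python
def nearest_unchecked(x, y, batch_x, batch_y, out):
    """Hypothetical model: write out[i] only if point i's batch has points in y.

    Entries for empty batches are never written, so they keep whatever the
    reused buffer `out` held from a previous call.
    """
    for i, (p, b) in enumerate(zip(x, batch_x)):
        candidates = [j for j, bj in enumerate(batch_y) if bj == b]
        if not candidates:
            continue  # nothing written: the stale value remains
        # Index of the closest candidate by squared Euclidean distance.
        out[i] = min(candidates, key=lambda j: sum((a - c) ** 2 for a, c in zip(p, y[j])))
    return list(out)  # snapshot of the buffer after this call


pos = [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
buffer = [0] * 5  # assumption: the reused memory happens to start as zeros

out4 = nearest_unchecked(pos, pos[:4], [0, 1, 2, 3, 4], [0, 1, 2, 3], buffer)
out5 = nearest_unchecked(pos, pos, [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], buffer)
out6 = nearest_unchecked(pos, pos[:4], [0, 1, 2, 3, 4], [0, 1, 2, 3], buffer)
print(out4)  # [0, 1, 2, 3, 0]
print(out5)  # [0, 1, 2, 3, 4]
print(out6)  # [0, 1, 2, 3, 4]
```

In this model, the last entry of the invalid call simply echoes whatever the previous call left in the buffer. That matches the difference between Out[4] and Out[6].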
I can't reproduce the illegal memory access error in isolation, but it occurred during training in exactly this kind of situation. If it helps, I can try to provide more details. Possibly the illegal memory access only happens when the batch index missing from batch_y is the last batch index in batch_x. At least, that was the case in the crashes I observed.
I'm not sure whether you'd consider this behavior a bug, since it is caused by invalid inputs. If a batch index is present in batch_x but missing from batch_y, the nearest operation is not well-defined, because it is being asked to find the closest point in an empty set. However, missing batch indices can easily occur when the data contains an empty graph. This was also quite a hard bug for me to find, since it is such an edge case and it triggered the illegal memory access. Therefore I think that raising an error, or outputting a consistent value such as -1 for the affected points, would be preferable.
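Either proposed behavior could be pinned down with a small reference implementation. The sketch below is plain Python for readability and is not the torch_cluster API. The name nearest_reference and the strict flag are hypothetical. For each point in x, it returns the index of the closest point in y that has the same batch index. Empty batches either raise an error or yield -1.

```python
def nearest_reference(x, y, batch_x, batch_y, strict=False):
    """Batch-aware nearest neighbour (squared Euclidean distance).

    For each point in x, return the index of the closest point in y with the
    same batch index. If that batch has no points in y, raise ValueError when
    strict=True; otherwise return -1 for the affected point.
    """
    # Group the indices of y by batch index.
    by_batch = {}
    for j, b in enumerate(batch_y):
        by_batch.setdefault(b, []).append(j)

    result = []
    for p, b in zip(x, batch_x):
        candidates = by_batch.get(b)
        if not candidates:
            if strict:
                raise ValueError(f"batch index {b} occurs in batch_x but not in batch_y")
            result.append(-1)  # consistent sentinel instead of undefined output
            continue
        result.append(min(candidates, key=lambda j: sum((a - c) ** 2 for a, c in zip(p, y[j]))))
    return result


pos = [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
print(nearest_reference(pos, pos[:4], [0, 1, 2, 3, 4], [0, 1, 2, 3]))  # [0, 1, 2, 3, -1]
```

Returning -1 keeps the output length unchanged, so callers can filter out the affected points. The strict mode instead surfaces the problem at the call site rather than deep inside a training run.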
Edit: here's my environment information
Environment
torch_cluster version: 1.6.0
PyG version: 2.0.4
PyTorch version: 1.10.2
OS: Ubuntu
Python version: 3.9.12
CUDA/cuDNN version: 11.6
How you installed PyTorch and PyG (conda, pip, source): conda