Illegal memory access with GPU implementation of nearest() if batch_y includes an empty instance  #163

Closed
@Vuenc

I noticed a potential bug in the nearest() function when it is called with certain invalid inputs.
When a batch index is present in the batch_x argument but not in the batch_y argument, nearest() returns non-deterministic outputs and sometimes causes the error CUDA error: an illegal memory access was encountered.

As an example:

In [1]: import torch_cluster
In [2]: import torch
In [3]: pos = torch.tensor([[0,0,0.],[1,1,1,],[2,2,2],[3,3,3],[4,4,4]]).cuda()
In [4]: torch_cluster.nearest(x=pos, y=pos[[0,1,2,3]],
                              batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3]).cuda())
Out[4]: tensor([0, 1, 2, 3, 0], device='cuda:0') # invalid inputs: for the empty batch, 0 is output

In [5]: torch_cluster.nearest(x=pos, y=pos,
                              batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3,4]).cuda())
Out[5]: tensor([0, 1, 2, 3, 4], device='cuda:0') # (valid inputs, output as expected)

In [6]: torch_cluster.nearest(x=pos, y=pos[[0,1,2,3]],
                              batch_x=torch.tensor([0,1,2,3,4]).cuda(), batch_y=torch.tensor([0,1,2,3]).cuda())
Out[6]: tensor([0, 1, 2, 3, 4], device='cuda:0') # invalid inputs: same code as In[4], but for the empty batch, 4 is output

Here the In [4] and In [6] calls are identical but produce different results. In the first case, the point [4,4,4] has batch index 4 in batch_x but no nearest neighbor, because no point in batch_y has batch index 4, and the result index 0 is output. In [5] is a call with valid inputs. In In [6], the same point [4,4,4] gets the result index 4 instead, which is not a valid index into y at all, since y has only four points. I wonder if the reason for this could be a read of uninitialized memory.
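The condition described above can be checked before calling nearest(). This is a minimal sketch, assuming the batch vectors are 1-D integer tensors or plain sequences; missing_batch_indices is a hypothetical helper name and not part of torch_cluster.

```python
def missing_batch_indices(batch_x, batch_y):
    """Return the sorted batch indices present in batch_x but absent from batch_y."""
    def as_list(batch):
        # torch tensors expose .tolist(); plain sequences are copied as-is.
        return batch.tolist() if hasattr(batch, "tolist") else list(batch)

    return sorted(set(as_list(batch_x)) - set(as_list(batch_y)))


# Batch vectors from the example above, written as plain lists.
print(missing_batch_indices([0, 1, 2, 3, 4], [0, 1, 2, 3]))     # [4]
print(missing_batch_indices([0, 1, 2, 3, 4], [0, 1, 2, 3, 4]))  # []
```

A non-empty result means at least one point in x has no candidate in y, which is the situation where the outputs above become inconsistent.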

I can't reproduce the illegal memory access error in isolation, but it occurred during a training run in exactly a situation like this. If it helps you, I can try to provide more details. Possibly the illegal memory access only happens if the batch index missing from batch_y is at the end of batch_x; at least that seemed to be the case in the crashes I observed.

I'm not sure if you'd consider this behavior a bug, because it is caused by invalid inputs: if a batch index is present in batch_x but missing from batch_y, the nearest operation is not well-defined, because we are asking it to find the closest point in an empty set. However, missing batch indices can easily occur if there's an empty graph in the data, and this was quite a hard-to-find bug for me because it is such an edge case and it triggered the illegal memory access. Therefore I think that raising an error or outputting a consistent value like -1 for affected points would be preferable.
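To make the proposed -1 behavior concrete, here is a brute-force reference sketch in plain Python. It is not how torch_cluster implements nearest(); nearest_reference and its fill parameter are hypothetical names, and the quadratic loop is only meant for checking small cases like the one above.

```python
def nearest_reference(x, y, batch_x, batch_y, fill=-1):
    """Brute-force nearest neighbor within each batch.

    For every point in x, return the index of the closest point in y that
    shares its batch index, or `fill` if y has no point in that batch.
    """
    result = []
    for point, b in zip(x, batch_x):
        best_index, best_dist = fill, float("inf")
        for j, (candidate, by) in enumerate(zip(y, batch_y)):
            if by != b:
                continue  # only compare points from the same batch
            dist = sum((p - c) ** 2 for p, c in zip(point, candidate))
            if dist < best_dist:
                best_index, best_dist = j, dist
        result.append(best_index)
    return result


pos = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
       [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]

# Same inputs as In [4] / In [6]: batch 4 has no points in y.
print(nearest_reference(pos, pos[:4], [0, 1, 2, 3, 4], [0, 1, 2, 3]))  # [0, 1, 2, 3, -1]
```

With this convention the result for the empty batch is the same on every call, and callers can detect affected points by testing for the fill value.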

Edit: here's my environment information

Environment

torch_cluster version: 1.6.0
PyG version: 2.0.4
PyTorch version: 1.10.2
OS: Ubuntu
Python version: 3.9.12
CUDA/cuDNN version: 11.6
How you installed PyTorch and PyG (conda, pip, source): conda
