Skip to content

Commit 14271b2

Browse files
committed
Error on unsorted batch_x/batch_y in nearest()
1 parent 7f80eba commit 14271b2

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

test/test_nearest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,15 @@ def test_nearest(dtype, device):
5656
batch_y = tensor([0, 1, 3, 3], torch.long, device)
5757
with pytest.raises(ValueError):
5858
out = nearest(x, y, batch_x, batch_y)
59+
60+
# Invalid input: batch_x unsorted
61+
batch_x = tensor([0, 0, 1, 0, 0, 0, 0], torch.long, device)
62+
batch_y = tensor([0, 0, 1, 1], torch.long, device)
63+
with pytest.raises(ValueError):
64+
out = nearest(x, y, batch_x, batch_y)
65+
66+
# Invalid input: batch_y unsorted
67+
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
68+
batch_y = tensor([0, 0, 1, 0], torch.long, device)
69+
with pytest.raises(ValueError):
70+
out = nearest(x, y, batch_x, batch_y)

torch_cluster/nearest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
4242
y = y.view(-1, 1) if y.dim() == 1 else y
4343
assert x.size(1) == y.size(1)
4444

45+
if batch_x is not None and (batch_x[1:] - batch_x[:-1] < 0).any():
46+
raise ValueError("batch_x is not sorted")
47+
if batch_y is not None and (batch_y[1:] - batch_y[:-1] < 0).any():
48+
raise ValueError("batch_y is not sorted")
49+
4550
if x.is_cuda:
4651
if batch_x is not None:
4752
assert x.size(0) == batch_x.numel()

0 commit comments

Comments
 (0)