@@ -33,3 +33,38 @@ def test_nearest(dtype, device):
33
33
34
34
out = nearest (x , y )
35
35
assert out .tolist () == [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ]
36
+
37
+ # Invalid input: instance 1 only in batch_x
38
+ batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
39
+ batch_y = tensor ([0 , 0 , 0 , 0 ], torch .long , device )
40
+ with pytest .raises (ValueError ):
41
+ out = nearest (x , y , batch_x , batch_y )
42
+
43
+ # Invalid input: instance 1 only in batch_x (implicitly as batch_y=None)
44
+ with pytest .raises (ValueError ):
45
+ out = nearest (x , y , batch_x , batch_y = None )
46
+
47
+ # Valid input: instance 1 only in batch_y
48
+ batch_x = tensor ([0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ], torch .long , device )
49
+ batch_y = tensor ([0 , 0 , 1 , 1 ], torch .long , device )
50
+ out = nearest (x , y , batch_x , batch_y )
51
+ assert out .tolist () == [0 , 0 , 1 , 1 , 0 , 0 , 1 , 1 ]
52
+
53
+ # Invalid input: instance 2 only in batch_x
54
+ # (i.e.instance in the middle missing)
55
+ batch_x = tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 ], torch .long , device )
56
+ batch_y = tensor ([0 , 1 , 3 , 3 ], torch .long , device )
57
+ with pytest .raises (ValueError ):
58
+ out = nearest (x , y , batch_x , batch_y )
59
+
60
+ # Invalid input: batch_x unsorted
61
+ batch_x = tensor ([0 , 0 , 1 , 0 , 0 , 0 , 0 ], torch .long , device )
62
+ batch_y = tensor ([0 , 0 , 1 , 1 ], torch .long , device )
63
+ with pytest .raises (ValueError ):
64
+ out = nearest (x , y , batch_x , batch_y )
65
+
66
+ # Invalid input: batch_y unsorted
67
+ batch_x = tensor ([0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 ], torch .long , device )
68
+ batch_y = tensor ([0 , 0 , 1 , 0 ], torch .long , device )
69
+ with pytest .raises (ValueError ):
70
+ out = nearest (x , y , batch_x , batch_y )
0 commit comments