Skip to content

Commit 1363270

Browse files
committed
check for unique positions within probe
1 parent 27137e9 commit 1363270

3 files changed

Lines changed: 111 additions & 0 deletions

File tree

src/probeinterface/probe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,27 @@ def set_contacts(
279279
if positions.shape[1] != self.ndim:
280280
raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!")
281281

282+
# Check for duplicate positions
283+
unique_positions = np.unique(positions, axis=0)
284+
if len(unique_positions) != len(positions):
285+
# Find and report duplicates
286+
duplicates = {}
287+
for index, pos in enumerate(positions):
288+
pos_key = tuple(pos)
289+
if pos_key in duplicates:
290+
duplicates[pos_key].append(index)
291+
else:
292+
duplicates[pos_key] = [index]
293+
294+
duplicate_groups = {pos: indices for pos, indices in duplicates.items() if len(indices) > 1}
295+
duplicate_info = []
296+
for pos, indices in duplicate_groups.items():
297+
pos_str = f"({', '.join(map(str, pos))})"
298+
indices_str = f"[{', '.join(map(str, indices))}]"
299+
duplicate_info.append(f"Position {pos_str} appears at indices {indices_str}")
300+
301+
raise ValueError(f"Contact positions must be unique within a probe. Found {len(duplicate_groups)} duplicate(s): {'; '.join(duplicate_info)}")
302+
282303
self._contact_positions = positions
283304
n = positions.shape[0]
284305

tests/test_probe.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,71 @@ def test_save_to_zarr(tmp_path):
182182
assert probe == reloaded_probe, "Reloaded Probe object does not match the original"
183183

184184

185+
def test_position_uniqueness_validation():
186+
"""Test that the probe validates position uniqueness correctly."""
187+
# Case 1: Unique positions (should pass)
188+
unique_positions = np.array([[0, 0], [10, 10], [20, 20], [30, 30]])
189+
probe = Probe(ndim=2, si_units="um")
190+
probe.set_contacts(positions=unique_positions, shapes="circle", shape_params={"radius": 5})
191+
assert probe.get_contact_count() == 4
192+
193+
# Case 2: Duplicate positions (should fail)
194+
duplicate_positions = np.array([[0, 0], [10, 10], [0, 0], [30, 30]])
195+
probe_dup = Probe(ndim=2, si_units="um")
196+
with pytest.raises(ValueError, match="Contact positions must be unique within a probe"):
197+
probe_dup.set_contacts(positions=duplicate_positions, shapes="circle", shape_params={"radius": 5})
198+
199+
# Case 3: Multiple duplicate positions
200+
multiple_dup_positions = np.array([[0, 0], [10, 10], [0, 0], [10, 10]])
201+
probe_multi_dup = Probe(ndim=2, si_units="um")
202+
with pytest.raises(ValueError, match="Contact positions must be unique within a probe"):
203+
probe_multi_dup.set_contacts(positions=multiple_dup_positions, shapes="circle", shape_params={"radius": 5})
204+
205+
# Case 4: 3D positions uniqueness
206+
unique_3d_positions = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])
207+
probe_3d = Probe(ndim=3, si_units="um")
208+
plane_axes = np.zeros((3, 2, 3))
209+
plane_axes[:, 0, 0] = 1 # x-axis
210+
plane_axes[:, 1, 1] = 1 # y-axis
211+
probe_3d.set_contacts(positions=unique_3d_positions, shapes="circle", shape_params={"radius": 5}, plane_axes=plane_axes)
212+
assert probe_3d.get_contact_count() == 3
213+
214+
# Case 5: 3D duplicate positions (should fail)
215+
duplicate_3d_positions = np.array([[0, 0, 0], [10, 10, 10], [0, 0, 0]])
216+
probe_3d_dup = Probe(ndim=3, si_units="um")
217+
plane_axes_dup = np.zeros((3, 2, 3))
218+
plane_axes_dup[:, 0, 0] = 1
219+
plane_axes_dup[:, 1, 1] = 1
220+
with pytest.raises(ValueError, match="Contact positions must be unique within a probe"):
221+
probe_3d_dup.set_contacts(positions=duplicate_3d_positions, shapes="circle", shape_params={"radius": 5}, plane_axes=plane_axes_dup)
222+
223+
# Case 6: Very close positions that are actually different (should pass)
224+
close_positions = np.array([[0.0, 0.0], [0.001, 0.0], [0.0, 0.001], [0.001, 0.001]])
225+
probe_close = Probe(ndim=2, si_units="um")
226+
probe_close.set_contacts(positions=close_positions, shapes="circle", shape_params={"radius": 5})
227+
assert probe_close.get_contact_count() == 4
228+
229+
# Case 7: Exactly same positions due to floating point precision (should fail)
230+
exact_same_positions = np.array([[0.1, 0.1], [0.2, 0.2], [0.1, 0.1]])
231+
probe_exact = Probe(ndim=2, si_units="um")
232+
with pytest.raises(ValueError, match="Contact positions must be unique within a probe"):
233+
probe_exact.set_contacts(positions=exact_same_positions, shapes="circle", shape_params={"radius": 5})
234+
235+
236+
def test_position_uniqueness_error_message():
237+
"""Test that the error message matches the full expected string for three duplicates using pytest's match regex."""
238+
import re
239+
positions_with_dups = np.array([[0, 0], [10, 10], [0, 0], [20, 20], [0, 0], [10, 10]])
240+
probe = Probe(ndim=2, si_units="um")
241+
expected_error = (
242+
"Contact positions must be unique within a probe. "
243+
"Found 2 duplicate(s): Position (0, 0) appears at indices [0, 2, 4]; Position (10, 10) appears at indices [1, 5]"
244+
)
245+
246+
with pytest.raises(ValueError, match=re.escape(expected_error)):
247+
probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5})
248+
249+
185250
if __name__ == "__main__":
186251
test_probe()
187252

tests/test_probegroup.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,31 @@ def test_probegroup_3d():
6767
assert probegroup.ndim == 3
6868

6969

70+
def test_probegroup_allows_duplicate_positions_across_probes():
71+
"""Test that ProbeGroup allows duplicate contact positions if they are in different probes."""
72+
from probeinterface import ProbeGroup, Probe
73+
import numpy as np
74+
75+
# Probes have the same internal relative positions
76+
positions = np.array([[0, 0], [10, 10]])
77+
probe1 = Probe(ndim=2, si_units="um")
78+
probe1.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
79+
probe2 = Probe(ndim=2, si_units="um")
80+
probe2.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
81+
82+
group = ProbeGroup()
83+
group.add_probe(probe1)
84+
group.add_probe(probe2)
85+
86+
# Should not raise any error
87+
all_positions = np.vstack([p.contact_positions for p in group.probes])
88+
# There are duplicates across probes, but this is allowed
89+
assert (all_positions == [0, 0]).any()
90+
assert (all_positions == [10, 10]).any()
91+
# The group should have both probes
92+
assert len(group.probes) == 2
93+
94+
7095
if __name__ == "__main__":
7196
test_probegroup()
7297
# ~ test_probegroup_3d()

0 commit comments

Comments
 (0)