Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,38 @@ def get_shanks(self):
shanks.append(shank)
return shanks

def __eq__(self, other):
if not isinstance(other, Probe):
return False

if not (
self.ndim == other.ndim
and self.si_units == other.si_units
and self.name == other.name
and self.serial_number == other.serial_number
and self.model_name == other.model_name
and self.manufacturer == other.manufacturer
and np.array_equal(self._contact_positions, other._contact_positions)
and np.array_equal(self._contact_plane_axes, other._contact_plane_axes)
and np.array_equal(self._contact_shapes, other._contact_shapes)
and np.array_equal(self._contact_shape_params, other._contact_shape_params)
and np.array_equal(self.probe_planar_contour, other.probe_planar_contour)
and np.array_equal(self._shank_ids, other._shank_ids)
and np.array_equal(self.device_channel_indices, other.device_channel_indices)
and np.array_equal(self._contact_ids, other._contact_ids)
and self.annotations == other.annotations
):
return False

# Compare contact_annotations dictionaries
if self.contact_annotations.keys() != other.contact_annotations.keys():
return False
for key in self.contact_annotations:
if not np.array_equal(self.contact_annotations[key], other.contact_annotations[key]):
return False

return True

def copy(self):
"""
Copy to another Probe instance.
Expand Down
95 changes: 52 additions & 43 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from probeinterface import Probe
from probeinterface.generator import generate_dummy_probe

import numpy as np

import pytest


def _dummy_position():
n = 24
positions = np.zeros((n, 2))
Expand All @@ -19,10 +21,10 @@ def _dummy_position():
def test_probe():
positions = _dummy_position()

probe = Probe(ndim=2, si_units='um')
probe.set_contacts(positions=positions, shapes='circle', shape_params={'radius': 5})
probe.set_contacts(positions=positions, shapes='square', shape_params={'width': 5})
probe.set_contacts(positions=positions, shapes='rect', shape_params={'width': 8, 'height':5 })
probe = Probe(ndim=2, si_units="um")
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
probe.set_contacts(positions=positions, shapes="square", shape_params={"width": 5})
probe.set_contacts(positions=positions, shapes="rect", shape_params={"width": 8, "height": 5})

assert probe.get_contact_count() == 24

Expand All @@ -34,20 +36,20 @@ def test_probe():
probe.create_auto_shape()

# annotation
probe.annotate(manufacturer='me')
assert 'manufacturer' in probe.annotations
probe.annotate_contacts(impedance=np.random.rand(24)*1000)
assert 'impedance' in probe.contact_annotations
probe.annotate(manufacturer="me")
assert "manufacturer" in probe.annotations
probe.annotate_contacts(impedance=np.random.rand(24) * 1000)
assert "impedance" in probe.contact_annotations

# device channel
chans = np.arange(0, 24, dtype='int')
chans = np.arange(0, 24, dtype="int")
np.random.shuffle(chans)
probe.set_device_channel_indices(chans)

# contact_ids int or str
elec_ids = np.arange(24)
probe.set_contact_ids(elec_ids)
elec_ids = [f'elec #{e}' for e in range(24)]
elec_ids = [f"elec #{e}" for e in range(24)]
probe.set_contact_ids(elec_ids)

# copy
Expand All @@ -59,18 +61,17 @@ def test_probe():

# make annimage
values = np.random.randn(24)
image, xlims, ylims = probe.to_image(values, method='cubic')

image2, xlims, ylims = probe.to_image(values, method='cubic', num_pixel=16)
image, xlims, ylims = probe.to_image(values, method="cubic")

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ fig, ax = plt.subplots()
#~ plot_probe(probe, ax=ax)
#~ ax.imshow(image, extent=xlims+ylims, origin='lower')
#~ ax.imshow(image2, extent=xlims+ylims, origin='lower')
#~ plt.show()
image2, xlims, ylims = probe.to_image(values, method="cubic", num_pixel=16)

# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ fig, ax = plt.subplots()
# ~ plot_probe(probe, ax=ax)
# ~ ax.imshow(image, extent=xlims+ylims, origin='lower')
# ~ ax.imshow(image2, extent=xlims+ylims, origin='lower')
# ~ plt.show()

# 3d
probe_3d = probe.to_3d()
Expand All @@ -81,10 +82,10 @@ def test_probe():
probe_2d = probe_3d.to_2d(axes="xz")
assert np.allclose(probe_2d.contact_positions, probe_3d.contact_positions[:, [0, 2]])

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(probe_3d)
#~ plt.show()
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe(probe_3d)
# ~ plt.show()

# get shanks
for shank in probe.get_shanks():
Expand All @@ -110,40 +111,48 @@ def test_probe():
df = probe.to_dataframe(complete=False)
other2 = Probe.from_dataframe(df)
df = probe_3d.to_dataframe(complete=True)
# print(df.index)
# print(df.index)
other_3d = Probe.from_dataframe(df)
assert other_3d.ndim == 3

# slice handling
selection = np.arange(0,18,2)
selection = np.arange(0, 18, 2)
# print(selection.dtype.kind)
sliced_probe = probe.get_slice(selection)
assert sliced_probe.get_contact_count() == 9
assert sliced_probe.contact_annotations['impedance'].shape == (9, )
assert sliced_probe.contact_annotations["impedance"].shape == (9,)

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(probe)
#~ plot_probe(sliced_probe)
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe(probe)
# ~ plot_probe(sliced_probe)

selection = np.ones(24, dtype='bool')
selection = np.ones(24, dtype="bool")
selection[::2] = False
sliced_probe = probe.get_slice(selection)
assert sliced_probe.get_contact_count() == 12
assert sliced_probe.contact_annotations['impedance'].shape == (12, )
assert sliced_probe.contact_annotations["impedance"].shape == (12,)

#~ plot_probe(probe)
#~ plot_probe(sliced_probe)
#~ plt.show()
# ~ plot_probe(probe)
# ~ plot_probe(sliced_probe)
# ~ plt.show()


def test_set_shanks():
probe = Probe(ndim=2, si_units='um')
probe.set_contacts(
positions= np.arange(20).reshape(10, 2),
shapes='circle',
shape_params={'radius' : 5})
def test_probe_equality_dunder():
probe1 = generate_dummy_probe()
probe2 = generate_dummy_probe()

assert probe1 == probe1
assert probe2 == probe2
assert probe1 == probe2

# Modify probe2
probe2.move([1, 1])
assert probe2 != probe1

def test_set_shanks():
probe = Probe(ndim=2, si_units="um")
probe.set_contacts(positions=np.arange(20).reshape(10, 2), shapes="circle", shape_params={"radius": 5})

# for simplicity each contact is on separate shank
shank_ids = np.arange(10)
Expand All @@ -152,7 +161,7 @@ def test_set_shanks():
assert all(probe.shank_ids == shank_ids.astype(str))


if __name__ == '__main__':
if __name__ == "__main__":
test_probe()

test_set_shanks()