Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ repos:
rev: 23.12.1
hooks:
- id: black
files: ^src/
files: ^src/|^tests/
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test = [
"scipy",
"pandas",
"h5py",
]
]

docs = [
"pillow",
Expand Down
32 changes: 18 additions & 14 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from probeinterface import (generate_dummy_probe, generate_dummy_probe_group,
generate_tetrode, generate_linear_probe, generate_multi_columns_probe,
generate_multi_shank)
from probeinterface import (
generate_dummy_probe,
generate_dummy_probe_group,
generate_tetrode,
generate_linear_probe,
generate_multi_columns_probe,
generate_multi_shank,
)


from pathlib import Path
Expand All @@ -15,20 +20,19 @@ def test_generate():

tetrode = generate_tetrode()

multi_columns = generate_multi_columns_probe(num_columns=3,
num_contact_per_column=[10, 12, 10],
xpitch=22, ypitch=20,
y_shift_per_column=[0, -10, 0])
multi_columns = generate_multi_columns_probe(
num_columns=3, num_contact_per_column=[10, 12, 10], xpitch=22, ypitch=20, y_shift_per_column=[0, -10, 0]
)

linear = generate_linear_probe(num_elec=16, ypitch=20,
contact_shapes='square', contact_shape_params={'width': 15})
linear = generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="square", contact_shape_params={"width": 15})

multi_shank = generate_multi_shank()

#~ from probeinterface.plotting import plot_probe_group, plot_probe
#~ import matplotlib.pyplot as plt
#~ plot_probe(multi_shank, with_contact_id=True,)
#~ plt.show()
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe(multi_shank, with_contact_id=True,)
# ~ plt.show()

if __name__ == '__main__':

if __name__ == "__main__":
test_generate()
27 changes: 7 additions & 20 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_probeinterface_format(tmp_path):
# ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False)
# ~ plt.show()


def test_writeprobeinterface(tmp_path):
probe = generate_dummy_probe()
file_path = tmp_path / "test.prb"
Expand All @@ -61,7 +62,6 @@ def test_writeprobeinterface_raises_error_with_bad_input(tmp_path):
write_probeinterface(file_path, probe)



def test_BIDS_format(tmp_path):
folder_path = tmp_path / "test_BIDS"
folder_path.mkdir()
Expand All @@ -77,9 +77,7 @@ def test_BIDS_format(tmp_path):
# with BIDS specifications
n_els = sum([p.get_contact_count() for p in probegroup.probes])
# using np.random.choice to ensure uniqueness of contact ids
el_ids = np.random.choice(
np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els
).astype(str)
el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str)
for probe in probegroup.probes:
probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()])
probe.set_contact_ids(probe_el_ids)
Expand All @@ -102,12 +100,7 @@ def test_BIDS_format(tmp_path):
assert all(np.isin(probe_orig.contact_ids, probe_read.contact_ids))

# the transformation of contact order between the two probes
t = np.array(
[
list(probe_read.contact_ids).index(elid)
for elid in probe_orig.contact_ids
]
)
t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids])

assert all(probe_orig.contact_ids == probe_read.contact_ids[t])
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
Expand All @@ -116,21 +109,14 @@ def test_BIDS_format(tmp_path):
assert probe_orig.si_units == probe_read.si_units

for i in range(len(probe_orig.probe_planar_contour)):
assert all(
probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i]
)
assert all(probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i])
for sid, shape_params in enumerate(probe_orig.contact_shape_params):
assert shape_params == probe_read.contact_shape_params[t][sid]
for i in range(len(probe_orig.contact_positions)):
assert all(
probe_orig.contact_positions[i] == probe_read.contact_positions[t][i]
)
assert all(probe_orig.contact_positions[i] == probe_read.contact_positions[t][i])
for i in range(len(probe.contact_plane_axes)):
for dim in range(len(probe.contact_plane_axes[i])):
assert all(
probe_orig.contact_plane_axes[i][dim]
== probe_read.contact_plane_axes[t][i][dim]
)
assert all(probe_orig.contact_plane_axes[i][dim] == probe_read.contact_plane_axes[t][i][dim])


def test_BIDS_format_empty(tmp_path):
Expand Down Expand Up @@ -218,6 +204,7 @@ def test_prb(tmp_path):
# plot_probe(probe)
# plt.show()


if __name__ == "__main__":
# test_probeinterface_format()
# test_BIDS_format()
Expand Down
22 changes: 5 additions & 17 deletions tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,15 @@ def test_NP_Ultra():
assert len(np.unique(probeD.contact_positions[:, 0])) == 1



def test_NP1_subset():
# NP1 - 200 channels selected by recording_state in Record Node
probe_ap = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP"
)
probe_ap = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP")

assert probe_ap.get_shank_count() == 1
assert "1.0" in probe_ap.model_name
assert probe_ap.get_contact_count() == 200

probe_lf = read_openephys(
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP"
)
probe_lf = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP")

assert probe_lf.get_shank_count() == 1
assert "1.0" in probe_lf.model_name
Expand All @@ -88,9 +83,7 @@ def test_NP1_subset():

def test_multiple_probes():
# multiple probes
probeA = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA"
)
probeA = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA")

assert probeA.get_shank_count() == 1
assert "1.0" in probeA.model_name
Expand All @@ -109,9 +102,7 @@ def test_multiple_probes():

assert probeC.get_shank_count() == 1

probeD = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD"
)
probeD = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD")

assert probeD.get_shank_count() == 1

Expand Down Expand Up @@ -148,13 +139,10 @@ def test_np_opto_with_sync():
assert probe.get_contact_count() == 384



def test_older_than_06_format():
## Test with the open ephys < 0.6 format

probe = read_openephys(
data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0"
)
probe = read_openephys(data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0")

assert probe.get_shank_count() == 4
assert "2.0 - Four Shank" in probe.model_name
Expand Down
15 changes: 8 additions & 7 deletions tests/test_io/test_spikeglx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx"


def test_parse_meta():
for meta_file in [
"doppio-checkerboard_t0.imec0.ap.meta",
Expand All @@ -19,19 +20,17 @@ def test_parse_meta():
]:
meta = parse_spikeglx_meta(data_path / meta_file)


def test_get_saved_channel_indices_from_spikeglx_meta():
# all channel saved + 1 synchro
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
data_path / "Noise_g0_t0.imec0.ap.meta"
)
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "Noise_g0_t0.imec0.ap.meta")
assert chan_inds.size == 385

# example by Pierre Yger NP1.0 with 384 but only 151 channels are saved + 1 synchro
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
data_path / "NP1_saved_only_subset_of_channels.meta"
)
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "NP1_saved_only_subset_of_channels.meta")
assert chan_inds.size == 152


def test_NP1():
probe = read_spikeglx(data_path / "Noise_g0_t0.imec0.ap.meta")
assert "1.0" in probe.model_name
Expand Down Expand Up @@ -187,6 +186,7 @@ def tes_NP1_384_channels():
assert probe.get_contact_count() == 151
assert 152 not in probe.contact_annotations["channel_ids"]


def test_NPH_long_staggered():
# Data provided by Nate Dolensek
probe = read_spikeglx(data_path / "non_human_primate_long_staggered.imec0.ap.meta")
Expand Down Expand Up @@ -242,6 +242,7 @@ def test_NPH_long_staggered():
assert np.allclose(references, 0)
assert np.allclose(filters, 1)


def test_NPH_short_linear_probe_type_0():
# Data provided by Jonathan A Michaels
probe = read_spikeglx(data_path / "non_human_primate_short_linear_probe_type_0.meta")
Expand All @@ -254,7 +255,6 @@ def test_NPH_short_linear_probe_type_0():
assert probe.get_shank_count() == 1
assert probe.get_contact_count() == 384


# Test contact geometry
x_pitch = 56.0
y_pitch = 20.0
Expand Down Expand Up @@ -320,6 +320,7 @@ def test_ultra_probe():
unique_y_values = np.unique(y)
assert unique_y_values.size == expected_electode_rows


def test_CatGT_NP1():
probe = read_spikeglx(data_path / "catgt.meta")
assert "1.0" in probe.model_name
Expand Down
12 changes: 6 additions & 6 deletions tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from probeinterface import Probe
from probeinterface.library import (download_probeinterface_file,
get_from_cache, get_probe)
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe


from pathlib import Path
Expand All @@ -9,19 +8,20 @@
import pytest


manufacturer = 'neuronexus'
probe_name = 'A1x32-Poly3-10mm-50-177'
manufacturer = "neuronexus"
probe_name = "A1x32-Poly3-10mm-50-177"


def test_download_probeinterface_file():
download_probeinterface_file(manufacturer, probe_name)


def test_get_from_cache():
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
assert isinstance(probe, Probe)

probe = get_from_cache('yep', 'yop')
probe = get_from_cache("yep", "yop")
assert probe is None


Expand All @@ -31,7 +31,7 @@ def test_get_probe():
assert probe.get_contact_count() == 32


if __name__ == '__main__':
if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
6 changes: 3 additions & 3 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def test_plot_probe():
plot_probe(probe)
plot_probe(probe, with_contact_id=True)
plot_probe(probe, with_device_index=True)
plot_probe(probe, text_on_contact=['abcde'[i%5] for i in range(probe.get_contact_count())])
plot_probe(probe, text_on_contact=["abcde"[i % 5] for i in range(probe.get_contact_count())])

# with color
n = probe.get_contact_count()
contacts_colors = np.random.rand(n, 3)
plot_probe(probe, contacts_colors=contacts_colors)

# 3d
probe_3d = probe.to_3d(axes='xz')
probe_3d = probe.to_3d(axes="xz")
plot_probe(probe_3d)

# on click
Expand All @@ -43,7 +43,7 @@ def test_plot_probe_group():
plot_probe_group(probegroup_3d, same_axes=True)


if __name__ == '__main__':
if __name__ == "__main__":
test_plot_probe()
# test_plot_probe_group()
plt.show()
Loading