Skip to content

Commit ee744f8

Browse files
authored
Merge pull request #249 from h-mayorquin/black_test
Black format tests
2 parents e4f36a1 + bc23643 commit ee744f8

13 files changed

+118
-146
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ repos:
99
rev: 24.1.1
1010
hooks:
1111
- id: black
12-
files: ^src/
12+
files: ^src/|^tests/

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ test = [
4545
"scipy",
4646
"pandas",
4747
"h5py",
48-
]
48+
]
4949

5050
docs = [
5151
"pillow",

tests/test_generator.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
from probeinterface import (generate_dummy_probe, generate_dummy_probe_group,
2-
generate_tetrode, generate_linear_probe, generate_multi_columns_probe,
3-
generate_multi_shank)
1+
from probeinterface import (
2+
generate_dummy_probe,
3+
generate_dummy_probe_group,
4+
generate_tetrode,
5+
generate_linear_probe,
6+
generate_multi_columns_probe,
7+
generate_multi_shank,
8+
)
49

510

611
from pathlib import Path
@@ -15,20 +20,19 @@ def test_generate():
1520

1621
tetrode = generate_tetrode()
1722

18-
multi_columns = generate_multi_columns_probe(num_columns=3,
19-
num_contact_per_column=[10, 12, 10],
20-
xpitch=22, ypitch=20,
21-
y_shift_per_column=[0, -10, 0])
23+
multi_columns = generate_multi_columns_probe(
24+
num_columns=3, num_contact_per_column=[10, 12, 10], xpitch=22, ypitch=20, y_shift_per_column=[0, -10, 0]
25+
)
2226

23-
linear = generate_linear_probe(num_elec=16, ypitch=20,
24-
contact_shapes='square', contact_shape_params={'width': 15})
27+
linear = generate_linear_probe(num_elec=16, ypitch=20, contact_shapes="square", contact_shape_params={"width": 15})
2528

2629
multi_shank = generate_multi_shank()
2730

28-
#~ from probeinterface.plotting import plot_probe_group, plot_probe
29-
#~ import matplotlib.pyplot as plt
30-
#~ plot_probe(multi_shank, with_contact_id=True,)
31-
#~ plt.show()
31+
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
32+
# ~ import matplotlib.pyplot as plt
33+
# ~ plot_probe(multi_shank, with_contact_id=True,)
34+
# ~ plt.show()
3235

33-
if __name__ == '__main__':
36+
37+
if __name__ == "__main__":
3438
test_generate()

tests/test_io/test_io.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_probeinterface_format(tmp_path):
4343
# ~ plot_probe_group(probegroup2, with_contact_id=True, same_axes=False)
4444
# ~ plt.show()
4545

46+
4647
def test_writeprobeinterface(tmp_path):
4748
probe = generate_dummy_probe()
4849
file_path = tmp_path / "test.prb"
@@ -61,7 +62,6 @@ def test_writeprobeinterface_raises_error_with_bad_input(tmp_path):
6162
write_probeinterface(file_path, probe)
6263

6364

64-
6565
def test_BIDS_format(tmp_path):
6666
folder_path = tmp_path / "test_BIDS"
6767
folder_path.mkdir()
@@ -77,9 +77,7 @@ def test_BIDS_format(tmp_path):
7777
# with BIDS specifications
7878
n_els = sum([p.get_contact_count() for p in probegroup.probes])
7979
# using np.random.choice to ensure uniqueness of contact ids
80-
el_ids = np.random.choice(
81-
np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els
82-
).astype(str)
80+
el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str)
8381
for probe in probegroup.probes:
8482
probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()])
8583
probe.set_contact_ids(probe_el_ids)
@@ -102,12 +100,7 @@ def test_BIDS_format(tmp_path):
102100
assert all(np.isin(probe_orig.contact_ids, probe_read.contact_ids))
103101

104102
# the transformation of contact order between the two probes
105-
t = np.array(
106-
[
107-
list(probe_read.contact_ids).index(elid)
108-
for elid in probe_orig.contact_ids
109-
]
110-
)
103+
t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids])
111104

112105
assert all(probe_orig.contact_ids == probe_read.contact_ids[t])
113106
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
@@ -116,21 +109,14 @@ def test_BIDS_format(tmp_path):
116109
assert probe_orig.si_units == probe_read.si_units
117110

118111
for i in range(len(probe_orig.probe_planar_contour)):
119-
assert all(
120-
probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i]
121-
)
112+
assert all(probe_orig.probe_planar_contour[i] == probe_read.probe_planar_contour[i])
122113
for sid, shape_params in enumerate(probe_orig.contact_shape_params):
123114
assert shape_params == probe_read.contact_shape_params[t][sid]
124115
for i in range(len(probe_orig.contact_positions)):
125-
assert all(
126-
probe_orig.contact_positions[i] == probe_read.contact_positions[t][i]
127-
)
116+
assert all(probe_orig.contact_positions[i] == probe_read.contact_positions[t][i])
128117
for i in range(len(probe.contact_plane_axes)):
129118
for dim in range(len(probe.contact_plane_axes[i])):
130-
assert all(
131-
probe_orig.contact_plane_axes[i][dim]
132-
== probe_read.contact_plane_axes[t][i][dim]
133-
)
119+
assert all(probe_orig.contact_plane_axes[i][dim] == probe_read.contact_plane_axes[t][i][dim])
134120

135121

136122
def test_BIDS_format_empty(tmp_path):
@@ -218,6 +204,7 @@ def test_prb(tmp_path):
218204
# plot_probe(probe)
219205
# plt.show()
220206

207+
221208
if __name__ == "__main__":
222209
# test_probeinterface_format()
223210
# test_BIDS_format()

tests/test_io/test_openephys.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,15 @@ def test_NP_Ultra():
6161
assert len(np.unique(probeD.contact_positions[:, 0])) == 1
6262

6363

64-
6564
def test_NP1_subset():
6665
# NP1 - 200 channels selected by recording_state in Record Node
67-
probe_ap = read_openephys(
68-
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP"
69-
)
66+
probe_ap = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP")
7067

7168
assert probe_ap.get_shank_count() == 1
7269
assert "1.0" in probe_ap.model_name
7370
assert probe_ap.get_contact_count() == 200
7471

75-
probe_lf = read_openephys(
76-
data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP"
77-
)
72+
probe_lf = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP")
7873

7974
assert probe_lf.get_shank_count() == 1
8075
assert "1.0" in probe_lf.model_name
@@ -88,9 +83,7 @@ def test_NP1_subset():
8883

8984
def test_multiple_probes():
9085
# multiple probes
91-
probeA = read_openephys(
92-
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA"
93-
)
86+
probeA = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA")
9487

9588
assert probeA.get_shank_count() == 1
9689
assert "1.0" in probeA.model_name
@@ -109,9 +102,7 @@ def test_multiple_probes():
109102

110103
assert probeC.get_shank_count() == 1
111104

112-
probeD = read_openephys(
113-
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD"
114-
)
105+
probeD = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD")
115106

116107
assert probeD.get_shank_count() == 1
117108

@@ -148,13 +139,10 @@ def test_np_opto_with_sync():
148139
assert probe.get_contact_count() == 384
149140

150141

151-
152142
def test_older_than_06_format():
153143
## Test with the open ephys < 0.6 format
154144

155-
probe = read_openephys(
156-
data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0"
157-
)
145+
probe = read_openephys(data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0")
158146

159147
assert probe.get_shank_count() == 4
160148
assert "2.0 - Four Shank" in probe.model_name

tests/test_io/test_spikeglx.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx"
1313

14+
1415
def test_parse_meta():
1516
for meta_file in [
1617
"doppio-checkerboard_t0.imec0.ap.meta",
@@ -19,19 +20,17 @@ def test_parse_meta():
1920
]:
2021
meta = parse_spikeglx_meta(data_path / meta_file)
2122

23+
2224
def test_get_saved_channel_indices_from_spikeglx_meta():
2325
# all channel saved + 1 synchro
24-
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
25-
data_path / "Noise_g0_t0.imec0.ap.meta"
26-
)
26+
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "Noise_g0_t0.imec0.ap.meta")
2727
assert chan_inds.size == 385
2828

2929
# example by Pierre Yger NP1.0 with 384 but only 151 channels are saved + 1 synchro
30-
chan_inds = get_saved_channel_indices_from_spikeglx_meta(
31-
data_path / "NP1_saved_only_subset_of_channels.meta"
32-
)
30+
chan_inds = get_saved_channel_indices_from_spikeglx_meta(data_path / "NP1_saved_only_subset_of_channels.meta")
3331
assert chan_inds.size == 152
3432

33+
3534
def test_NP1():
3635
probe = read_spikeglx(data_path / "Noise_g0_t0.imec0.ap.meta")
3736
assert "1.0" in probe.model_name
@@ -187,6 +186,7 @@ def tes_NP1_384_channels():
187186
assert probe.get_contact_count() == 151
188187
assert 152 not in probe.contact_annotations["channel_ids"]
189188

189+
190190
def test_NPH_long_staggered():
191191
# Data provided by Nate Dolensek
192192
probe = read_spikeglx(data_path / "non_human_primate_long_staggered.imec0.ap.meta")
@@ -242,6 +242,7 @@ def test_NPH_long_staggered():
242242
assert np.allclose(references, 0)
243243
assert np.allclose(filters, 1)
244244

245+
245246
def test_NPH_short_linear_probe_type_0():
246247
# Data provided by Jonathan A Michaels
247248
probe = read_spikeglx(data_path / "non_human_primate_short_linear_probe_type_0.meta")
@@ -254,7 +255,6 @@ def test_NPH_short_linear_probe_type_0():
254255
assert probe.get_shank_count() == 1
255256
assert probe.get_contact_count() == 384
256257

257-
258258
# Test contact geometry
259259
x_pitch = 56.0
260260
y_pitch = 20.0
@@ -320,6 +320,7 @@ def test_ultra_probe():
320320
unique_y_values = np.unique(y)
321321
assert unique_y_values.size == expected_electode_rows
322322

323+
323324
def test_CatGT_NP1():
324325
probe = read_spikeglx(data_path / "catgt.meta")
325326
assert "1.0" in probe.model_name

tests/test_library.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from probeinterface import Probe
2-
from probeinterface.library import (download_probeinterface_file,
3-
get_from_cache, get_probe)
2+
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe
43

54

65
from pathlib import Path
@@ -9,19 +8,20 @@
98
import pytest
109

1110

12-
manufacturer = 'neuronexus'
13-
probe_name = 'A1x32-Poly3-10mm-50-177'
11+
manufacturer = "neuronexus"
12+
probe_name = "A1x32-Poly3-10mm-50-177"
1413

1514

1615
def test_download_probeinterface_file():
1716
download_probeinterface_file(manufacturer, probe_name)
1817

18+
1919
def test_get_from_cache():
2020
download_probeinterface_file(manufacturer, probe_name)
2121
probe = get_from_cache(manufacturer, probe_name)
2222
assert isinstance(probe, Probe)
2323

24-
probe = get_from_cache('yep', 'yop')
24+
probe = get_from_cache("yep", "yop")
2525
assert probe is None
2626

2727

@@ -31,7 +31,7 @@ def test_get_probe():
3131
assert probe.get_contact_count() == 32
3232

3333

34-
if __name__ == '__main__':
34+
if __name__ == "__main__":
3535
test_download_probeinterface_file()
3636
test_get_from_cache()
3737
test_get_probe()

tests/test_plotting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ def test_plot_probe():
1313
plot_probe(probe)
1414
plot_probe(probe, with_contact_id=True)
1515
plot_probe(probe, with_device_index=True)
16-
plot_probe(probe, text_on_contact=['abcde'[i%5] for i in range(probe.get_contact_count())])
16+
plot_probe(probe, text_on_contact=["abcde"[i % 5] for i in range(probe.get_contact_count())])
1717

1818
# with color
1919
n = probe.get_contact_count()
2020
contacts_colors = np.random.rand(n, 3)
2121
plot_probe(probe, contacts_colors=contacts_colors)
2222

2323
# 3d
24-
probe_3d = probe.to_3d(axes='xz')
24+
probe_3d = probe.to_3d(axes="xz")
2525
plot_probe(probe_3d)
2626

2727
# on click
@@ -43,7 +43,7 @@ def test_plot_probe_group():
4343
plot_probe_group(probegroup_3d, same_axes=True)
4444

4545

46-
if __name__ == '__main__':
46+
if __name__ == "__main__":
4747
test_plot_probe()
4848
# test_plot_probe_group()
4949
plt.show()

0 commit comments

Comments
 (0)