1
- import pytest
1
+ import shutil
2
2
from pathlib import Path
3
3
4
- from spikewrap .structure ._preprocess_run import PreprocessedRun
5
- import spikewrap .visualise
6
- import numpy as np
7
4
import matplotlib .pyplot as plt
8
- import shutil
5
+ import numpy as np
6
+ import pytest
7
+
8
+ from spikewrap .structure ._preprocess_run import PreprocessedRun
9
+
9
10
10
11
@pytest .fixture
11
12
def mock_preprocessed_run (tmp_path , monkeypatch ):
12
13
"""
13
14
Fixture to create a temporary PreprocessedRun instance with mock data.
14
15
"""
15
16
16
- from spikewrap .structure ._preprocess_run import PreprocessedRun
17
-
18
17
def mock_plot (* args , ** kwargs ):
19
18
pass
20
-
19
+
21
20
def mock_figure (* args , ** kwargs ):
22
21
class MockFigure :
23
22
def savefig (self , path ):
24
23
Path (path ).parent .mkdir (parents = True , exist_ok = True )
25
24
Path (path ).touch ()
26
-
25
+
27
26
def clf (self ):
28
27
pass
29
-
28
+
30
29
return MockFigure ()
31
-
30
+
32
31
def mock_visualise (* args , ** kwargs ):
33
32
return mock_figure ()
34
-
33
+
35
34
monkeypatch .setattr (plt , "figure" , mock_figure )
36
35
monkeypatch .setattr (plt , "plot" , mock_plot )
37
36
monkeypatch .setattr (plt , "subplot" , lambda * args , ** kwargs : None )
38
37
monkeypatch .setattr (plt , "title" , lambda * args , ** kwargs : None )
39
-
38
+
40
39
import sys
40
+
41
41
module_name = PreprocessedRun .__module__
42
42
module = sys .modules [module_name ]
43
43
monkeypatch .setattr (module , "visualise_run_preprocessed" , mock_visualise )
44
-
44
+
45
45
class MockRecording :
46
46
def __init__ (self ):
47
47
self .properties = {}
48
- self .data = np .random .random ((10 , 1000 ))
49
-
48
+ self .data = np .random .random ((10 , 1000 ))
49
+
50
50
def save (self , folder , chunk_duration ):
51
51
Path (folder ).mkdir (parents = True , exist_ok = True )
52
52
(Path (folder ) / "mock_recording_saved.txt" ).touch ()
53
53
return True
54
-
54
+
55
55
def get_property (self , property_name ):
56
56
return self .properties .get (property_name , [])
57
-
57
+
58
58
def get_traces (self , * args , ** kwargs ):
59
59
return self .data
60
-
60
+
61
61
def __array__ (self ):
62
62
return self .data
63
-
63
+
64
64
# Set up a mock recording with bad channels
65
65
mock_recording = MockRecording ()
66
- mock_recording .properties ["bad_channels" ] = [0 , 1 ]
67
-
66
+ mock_recording .properties ["bad_channels" ] = [0 , 1 ]
67
+
68
68
raw_data_path = tmp_path / "raw_data"
69
69
session_output_path = tmp_path / "output"
70
70
run_name = "test_run"
71
-
71
+
72
72
preprocessed_data = {"shank_0" : {"0" : mock_recording , "1" : mock_recording }}
73
73
74
- raw_data_path .mkdir (parents = True , exist_ok = True )
74
+ raw_data_path .mkdir (parents = True , exist_ok = True )
75
75
session_output_path .mkdir (parents = True , exist_ok = True )
76
-
76
+
77
77
preprocessed_path = session_output_path / run_name / "preprocessed"
78
78
preprocessed_path .mkdir (parents = True , exist_ok = True )
79
-
79
+
80
80
diagnostic_path = session_output_path / "diagnostic_plots"
81
81
diagnostic_path .mkdir (parents = True , exist_ok = True )
82
-
82
+
83
83
preprocessed_run = PreprocessedRun (
84
84
raw_data_path = raw_data_path ,
85
85
ses_name = "test_session" ,
@@ -89,20 +89,24 @@ def __array__(self):
89
89
preprocessed_data = preprocessed_data ,
90
90
pp_steps = {"step_1" : "bad_channel_detection" },
91
91
)
92
-
92
+
93
93
def mock_save_diagnostic_plots (self ):
94
94
diagnostic_path = self ._output_path / "diagnostic_plots"
95
95
diagnostic_path .mkdir (parents = True , exist_ok = True )
96
-
96
+
97
97
for shank_name in self ._preprocessed :
98
98
(diagnostic_path / f"{ shank_name } _before_detection.png" ).touch ()
99
99
(diagnostic_path / f"{ shank_name } _after_detection.png" ).touch ()
100
-
100
+
101
101
for ch in [0 , 1 ]:
102
102
(diagnostic_path / f"{ shank_name } _bad_channel_{ ch } .png" ).touch ()
103
-
103
+
104
104
# Monkeypatch the method to create placeholder files instead of real plots
105
- monkeypatch .setattr (preprocessed_run , "_save_diagnostic_plots" , mock_save_diagnostic_plots .__get__ (preprocessed_run ))
105
+ monkeypatch .setattr (
106
+ preprocessed_run ,
107
+ "_save_diagnostic_plots" ,
108
+ mock_save_diagnostic_plots .__get__ (preprocessed_run ),
109
+ )
106
110
107
111
yield preprocessed_run
108
112
@@ -120,10 +124,14 @@ def test_diagnostic_plots_saved(self, mock_preprocessed_run):
120
124
121
125
if output_dir .exists ():
122
126
shutil .rmtree (output_dir )
123
- assert not output_dir .exists (), "Diagnostic plots directory should not exist before running save_preprocessed"
127
+ assert (
128
+ not output_dir .exists ()
129
+ ), "Diagnostic plots directory should not exist before running save_preprocessed"
124
130
125
131
# Should trigger the diagnostic plot saving
126
- mock_preprocessed_run .save_preprocessed (overwrite = True , chunk_duration_s = 1.0 , n_jobs = 1 , slurm = False )
132
+ mock_preprocessed_run .save_preprocessed (
133
+ overwrite = True , chunk_duration_s = 1.0 , n_jobs = 1 , slurm = False
134
+ )
127
135
128
136
assert output_dir .exists (), "Diagnostic plots directory was not created"
129
137
shank_name = "shank_0"
0 commit comments