6
6
BaseSorting ,
7
7
WaveformExtractor ,
8
8
NumpySorting ,
9
- NpzSortingExtractor ,
10
- InjectTemplatesRecording ,
11
9
)
12
10
from spikeinterface .core .core_tools import define_function_from_class
13
- from spikeinterface .core import generate_sorting
11
+ from spikeinterface .core . generate import generate_sorting , InjectTemplatesRecording , _ensure_seed
14
12
15
13
16
14
class HybridUnitsRecording (InjectTemplatesRecording ):
@@ -60,6 +58,7 @@ def __init__(
60
58
amplitude_std : float = 0.0 ,
61
59
refractory_period_ms : float = 2.0 ,
62
60
injected_sorting_folder : Union [str , Path , None ] = None ,
61
+ seed = None ,
63
62
):
64
63
num_samples = [
65
64
parent_recording .get_num_frames (seg_index ) for seg_index in range (parent_recording .get_num_segments ())
@@ -80,8 +79,8 @@ def __init__(
80
79
num_units = len (templates ),
81
80
sampling_frequency = fs ,
82
81
durations = durations ,
83
- firing_rate = firing_rate ,
84
- refractory_period = refractory_period_ms ,
82
+ firing_rates = firing_rate ,
83
+ refractory_period_ms = refractory_period_ms ,
85
84
)
86
85
# save injected sorting if necessary
87
86
self .injected_sorting = injected_sorting
@@ -90,17 +89,10 @@ def __init__(
90
89
self .injected_sorting = self .injected_sorting .save (folder = injected_sorting_folder )
91
90
92
91
if amplitude_factor is None :
93
- amplitude_factor = [
94
- [
95
- np .random .normal (
96
- loc = 1.0 ,
97
- scale = amplitude_std ,
98
- size = len (self .injected_sorting .get_unit_spike_train (unit_id , segment_index = seg_index )),
99
- )
100
- for unit_id in self .injected_sorting .unit_ids
101
- ]
102
- for seg_index in range (parent_recording .get_num_segments ())
103
- ]
92
+ seed = _ensure_seed (seed )
93
+ rng = np .random .default_rng (seed = seed )
94
+ num_spikes = self .injected_sorting .to_spike_vector ().size
95
+ amplitude_factor = rng .normal (loc = 1.0 , scale = amplitude_std , size = num_spikes )
104
96
105
97
InjectTemplatesRecording .__init__ (
106
98
self , self .injected_sorting , templates , nbefore , amplitude_factor , parent_recording , num_samples
@@ -116,6 +108,7 @@ def __init__(
116
108
amplitude_std = amplitude_std ,
117
109
refractory_period_ms = refractory_period_ms ,
118
110
injected_sorting_folder = None ,
111
+ seed = seed ,
119
112
)
120
113
121
114
0 commit comments