Skip to content

Commit 23aef27

Browse files
authored
Merge pull request #1948 from samuelgarcia/generator
Refactor generate.py
2 parents 3817ee0 + 20f5108 commit 23aef27

File tree

14 files changed

+1466
-1124
lines changed

14 files changed

+1466
-1124
lines changed

src/spikeinterface/comparison/hybrid.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
BaseSorting,
77
WaveformExtractor,
88
NumpySorting,
9-
NpzSortingExtractor,
10-
InjectTemplatesRecording,
119
)
1210
from spikeinterface.core.core_tools import define_function_from_class
13-
from spikeinterface.core import generate_sorting
11+
from spikeinterface.core.generate import generate_sorting, InjectTemplatesRecording, _ensure_seed
1412

1513

1614
class HybridUnitsRecording(InjectTemplatesRecording):
@@ -60,6 +58,7 @@ def __init__(
6058
amplitude_std: float = 0.0,
6159
refractory_period_ms: float = 2.0,
6260
injected_sorting_folder: Union[str, Path, None] = None,
61+
seed=None,
6362
):
6463
num_samples = [
6564
parent_recording.get_num_frames(seg_index) for seg_index in range(parent_recording.get_num_segments())
@@ -80,8 +79,8 @@ def __init__(
8079
num_units=len(templates),
8180
sampling_frequency=fs,
8281
durations=durations,
83-
firing_rate=firing_rate,
84-
refractory_period=refractory_period_ms,
82+
firing_rates=firing_rate,
83+
refractory_period_ms=refractory_period_ms,
8584
)
8685
# save injected sorting if necessary
8786
self.injected_sorting = injected_sorting
@@ -90,17 +89,10 @@ def __init__(
9089
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)
9190

9291
if amplitude_factor is None:
93-
amplitude_factor = [
94-
[
95-
np.random.normal(
96-
loc=1.0,
97-
scale=amplitude_std,
98-
size=len(self.injected_sorting.get_unit_spike_train(unit_id, segment_index=seg_index)),
99-
)
100-
for unit_id in self.injected_sorting.unit_ids
101-
]
102-
for seg_index in range(parent_recording.get_num_segments())
103-
]
92+
seed = _ensure_seed(seed)
93+
rng = np.random.default_rng(seed=seed)
94+
num_spikes = self.injected_sorting.to_spike_vector().size
95+
amplitude_factor = rng.normal(loc=1.0, scale=amplitude_std, size=num_spikes)
10496

10597
InjectTemplatesRecording.__init__(
10698
self, self.injected_sorting, templates, nbefore, amplitude_factor, parent_recording, num_samples
@@ -116,6 +108,7 @@ def __init__(
116108
amplitude_std=amplitude_std,
117109
refractory_period_ms=refractory_period_ms,
118110
injected_sorting_folder=None,
111+
seed=seed,
119112
)
120113

121114

src/spikeinterface/core/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
inject_some_duplicate_units,
3535
inject_some_split_units,
3636
synthetize_spike_train_bad_isi,
37+
generate_templates,
38+
NoiseGeneratorRecording,
39+
noise_generator_recording,
40+
generate_recording_by_size,
41+
InjectTemplatesRecording,
42+
inject_templates,
43+
generate_ground_truth_recording,
3744
)
3845

3946
# utils to append and concatenate segment (equivalent to OLD MultiRecordingTimeExtractor)
@@ -109,7 +116,7 @@
109116
)
110117

111118
# templates addition
112-
from .injecttemplates import InjectTemplatesRecording, InjectTemplatesRecordingSegment, inject_templates
119+
# from .injecttemplates import InjectTemplatesRecording, InjectTemplatesRecordingSegment, inject_templates
113120

114121
# template tools
115122
from .template_tools import (

0 commit comments

Comments
 (0)