diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 88c63ea30a..ec76fcbaa9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1,5 +1,4 @@ from __future__ import annotations - import math import warnings import numpy as np @@ -908,9 +907,9 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No return sorting_with_dup -def inject_some_split_units(sorting, split_ids=[], num_split=2, output_ids=False, seed=None): +def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None): """ """ - assert len(split_ids) > 0, "you need to provide some ids to split" + unit_ids = sorting.unit_ids assert unit_ids.dtype.kind == "i" @@ -1444,7 +1443,8 @@ def generate_templates( seed=None, dtype="float32", upsample_factor=None, - unit_params=dict(), + unit_params=None, + unit_params_range=None, mode="ellipsoid", ): """ @@ -1501,6 +1501,9 @@ def generate_templates( * (num_units, num_samples, num_channels, upsample_factor) if upsample_factor is not None """ + + unit_params = unit_params or dict() + unit_params_range = unit_params_range or dict() rng = np.random.default_rng(seed=seed) # neuron location must be 3D @@ -1968,7 +1971,7 @@ def generate_ground_truth_recording( generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0), noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20), - generate_templates_kwargs=dict(), + generate_templates_kwargs=None, dtype="float32", seed=None, ): @@ -2027,6 +2030,7 @@ def generate_ground_truth_recording( sorting: Sorting The generated sorting extractor. """ + generate_templates_kwargs = generate_templates_kwargs or dict() # TODO implement upsample_factor in InjectTemplatesRecording and propagate into toy_example diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index f8dea5b270..bbf861dac9 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -25,14 +25,15 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation): num_unit_splited = 1 num_split = 2 + split_ids = sorting.unit_ids[:num_unit_splited] sorting_with_split, other_ids = inject_some_split_units( - sorting, split_ids=sorting.unit_ids[:num_unit_splited], num_split=num_split, output_ids=True, seed=42 + sorting, + split_ids=split_ids, + num_split=num_split, + output_ids=True, + seed=42, ) - # print(sorting_with_split) - # print(sorting_with_split.unit_ids) - # print(other_ids) - job_kwargs = dict(n_jobs=-1) sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory") @@ -59,8 +60,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation): firing_contamination_balance=1.5, extra_outputs=True, ) - # print(potential_merges) - # print(num_unit_splited) assert len(potential_merges) == num_unit_splited for true_pair in other_ids.values():