Skip to content

Commit

Permalink
Merge pull request #2345 from h-mayorquin/remove_wrong_expressions_in…
Browse files Browse the repository at this point in the history
…_generate

Remove default values used as expressions in `generate.py`.
  • Loading branch information
alejoe91 authored Apr 12, 2024
2 parents f0bcd4f + 82e4e31 commit 49ae822
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
14 changes: 9 additions & 5 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations

import math
import warnings
import numpy as np
Expand Down Expand Up @@ -908,9 +907,9 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
return sorting_with_dup


def inject_some_split_units(sorting, split_ids=[], num_split=2, output_ids=False, seed=None):
def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=False, seed=None):
""" """
assert len(split_ids) > 0, "you need to provide some ids to split"

unit_ids = sorting.unit_ids
assert unit_ids.dtype.kind == "i"

Expand Down Expand Up @@ -1444,7 +1443,8 @@ def generate_templates(
seed=None,
dtype="float32",
upsample_factor=None,
unit_params=dict(),
unit_params=None,
unit_params_range=None,
mode="ellipsoid",
):
"""
Expand Down Expand Up @@ -1501,6 +1501,9 @@ def generate_templates(
* (num_units, num_samples, num_channels, upsample_factor) if upsample_factor is not None
"""

unit_params = unit_params or dict()
unit_params_range = unit_params_range or dict()
rng = np.random.default_rng(seed=seed)

# neuron location must be 3D
Expand Down Expand Up @@ -1968,7 +1971,7 @@ def generate_ground_truth_recording(
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=dict(),
generate_templates_kwargs=None,
dtype="float32",
seed=None,
):
Expand Down Expand Up @@ -2027,6 +2030,7 @@ def generate_ground_truth_recording(
sorting: Sorting
The generated sorting extractor.
"""
generate_templates_kwargs = generate_templates_kwargs or dict()

# TODO implement upsample_factor in InjectTemplatesRecording and propagate into toy_example

Expand Down
13 changes: 6 additions & 7 deletions src/spikeinterface/curation/tests/test_auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation):
num_unit_splited = 1
num_split = 2

split_ids = sorting.unit_ids[:num_unit_splited]
sorting_with_split, other_ids = inject_some_split_units(
sorting, split_ids=sorting.unit_ids[:num_unit_splited], num_split=num_split, output_ids=True, seed=42
sorting,
split_ids=split_ids,
num_split=num_split,
output_ids=True,
seed=42,
)

# print(sorting_with_split)
# print(sorting_with_split.unit_ids)
# print(other_ids)

job_kwargs = dict(n_jobs=-1)

sorting_analyzer = create_sorting_analyzer(sorting_with_split, recording, format="memory")
Expand All @@ -59,8 +60,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation):
firing_contamination_balance=1.5,
extra_outputs=True,
)
# print(potential_merges)
# print(num_unit_splited)

assert len(potential_merges) == num_unit_splited
for true_pair in other_ids.values():
Expand Down

0 comments on commit 49ae822

Please sign in to comment.