Skip to content

Commit

Permalink
units aggergation should preserve ids
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Jul 10, 2024
1 parent bd9cd1f commit b7c4309
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
46 changes: 43 additions & 3 deletions src/spikeinterface/core/tests/test_unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from spikeinterface.core import NpzSortingExtractor
from spikeinterface.core import create_sorting_npz
from spikeinterface.core import generate_sorting


def test_unitsaggregationsorting(create_cache_folder):
Expand Down Expand Up @@ -33,10 +34,12 @@ def test_unitsaggregationsorting(create_cache_folder):
spiketrain1_1 = sorting1.get_unit_spike_train(unit_ids[1], segment_index=seg)
spiketrains2_0 = sorting2.get_unit_spike_train(unit_ids[0], segment_index=seg)
spiketrains3_2 = sorting3.get_unit_spike_train(unit_ids[2], segment_index=seg)
assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(unit_ids[1], segment_index=seg))
assert np.allclose(spiketrains2_0, sorting_agg.get_unit_spike_train(num_units + unit_ids[0], segment_index=seg))
assert np.allclose(spiketrain1_1, sorting_agg.get_unit_spike_train(str(unit_ids[1]), segment_index=seg))
assert np.allclose(
spiketrains3_2, sorting_agg.get_unit_spike_train(2 * num_units + unit_ids[2], segment_index=seg)
spiketrains2_0, sorting_agg.get_unit_spike_train(str(num_units + unit_ids[0]), segment_index=seg)
)
assert np.allclose(
spiketrains3_2, sorting_agg.get_unit_spike_train(str(2 * num_units + unit_ids[2]), segment_index=seg)
)

# test rename units
Expand Down Expand Up @@ -92,5 +95,42 @@ def test_unitsaggregationsorting(create_cache_folder):
print(sorting_agg_prop.get_property("brain_area"))


def test_unit_aggregation_preserve_ids():

sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=3)
sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 6
assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"]


def test_unit_aggregation_does_not_preserve_ids_if_not_unique():
sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=3)
sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 6
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"]


def test_unit_aggregation_does_not_preserve_ids_not_the_same_type():
sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=2)
sorting2 = sorting2.rename_units(new_unit_ids=[1, 2])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 5
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"]


if __name__ == "__main__":
test_unitsaggregationsorting()
12 changes: 11 additions & 1 deletion src/spikeinterface/core/unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,17 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
)
unit_ids = list(renamed_unit_ids)
else:
unit_ids = list(np.arange(num_all_units))
all_ids_are_same_type = np.unique([sort.get_unit_ids().dtype for sort in sorting_list]).size == 1
all_units_ids_are_unique = False
if all_ids_are_same_type:
combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list])
all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units

if all_ids_are_same_type and all_units_ids_are_unique:
unit_ids = combined_ids
else:
default_unit_ids = [str(i) for i in range(num_all_units)]
unit_ids = default_unit_ids

# unit map maps unit ids that are used to get spike trains
u_id = 0
Expand Down

0 comments on commit b7c4309

Please sign in to comment.