compute_synchrony_metrics update #2605

Merged (29 commits, merged Apr 16, 2024)

Commits (the diff below shows changes from 22 of the 29 commits)
e8e4ca4  Update and test compute_synchrony_metrics (chrishalcrow, Mar 20, 2024)
220add7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 20, 2024)
cb45927  Make namedtuple and check div by zero (chrishalcrow, Mar 20, 2024)
b91f5f2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 20, 2024)
bfbc5e3  Redo tests to use namedtuple (chrishalcrow, Mar 20, 2024)
f92e2a9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 20, 2024)
4f5b119  Merge branch 'SpikeInterface:main' into sync_counts_update (chrishalcrow, Mar 26, 2024)
1bb06b6  Revert generate_stuff outputs, add tests + docstrings (chrishalcrow, Mar 26, 2024)
0706b4d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 26, 2024)
635d330  delete mentions of segments in synthesize_things doc (chrishalcrow, Mar 26, 2024)
f22698b  Change "spikes" description (chrishalcrow, Mar 26, 2024)
56f0559  type bins for numba (chrishalcrow, Mar 27, 2024)
3de7c27  revert to old synthesize_stuff functions (before cleanup) (chrishalcrow, Mar 27, 2024)
c53fedd  Merge branch 'main' into sync_counts_update (chrishalcrow, Mar 27, 2024)
c9244a2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 27, 2024)
7cfd0ea  oops, remove isi.py edits (chrishalcrow, Mar 27, 2024)
f3217f4  Merge branch 'main' into sync_counts_update (chrishalcrow, Mar 27, 2024)
51f0039  Merge branch 'main' into sync_counts_update (h-mayorquin, Mar 27, 2024)
f5d20fa  rename insertions and unpack spike train into indices, labels, segments (chrishalcrow, Apr 12, 2024)
23c3355  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2024)
7719ee1  Merge branch 'main' into sync_counts_update (chrishalcrow, Apr 12, 2024)
66077e2  Update src/spikeinterface/qualitymetrics/misc_metrics.py (chrishalcrow, Apr 12, 2024)
11d5d39  Change to None, add some ifs (chrishalcrow, Apr 15, 2024)
5de7b66  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 15, 2024)
378976f  Revert namedtuple name (chrishalcrow, Apr 15, 2024)
ecc6a9a  Merge branch 'main' into sync_counts_update (chrishalcrow, Apr 15, 2024)
f6652f0  Delete _add_spikes_to_spiketrain function, for simplicity (chrishalcrow, Apr 15, 2024)
0b9a58a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 15, 2024)
16b344c  synch -> sync (chrishalcrow, Apr 15, 2024)
File: src/spikeinterface/core/generate.py (79 additions, 0 deletions)

@@ -845,6 +845,85 @@ def clean_refractory_period(times, refractory_period):
    return times


def _add_spikes_to_spiketrain(
    spike_indices,
    spike_labels,
    segment_indices=[],

[Review comment from a Member]
Suggested change: segment_indices=[] -> segment_indices=None
See #2345. (A short sketch of this pitfall follows the diff.)

[Author's reply]
To quote Joe: "Thanks I did not know about that, I am learning so much working on the SI codebase 😍!"

    added_spikes_indices=None,
    added_spikes_labels=None,
    added_segment_indices=[],

[Review comment from a Member]
Suggested change: added_segment_indices=[] -> added_segment_indices=None

    replace=False,
    seed=None,
):
    """
    Add specified spikes into a spike train

    Parameters
    ----------
    spike_indices: array-like
        List of sample-indices for each spike
    spike_labels: array-like
        List of units for each spike
    segment_indices: array-like
        List of segment numbers for each spike
    added_spikes_indices: array-like
        List of sample-indices for added spikes
    added_spikes_labels: array-like
        List of units for added spikes
    added_segment_indices: array-like
        List of segments for added spikes
    replace: bool, default: False
        If True, randomly replace existing spikes with the added spikes. If False, append them to the existing spike train
    seed: int, default: None
        Seed for the random number generator used when replace=True

    Returns
    -------
    new_spike_indices, new_spike_labels: np.ndarray
        Arrays in the same form as the inputs, including the added spikes
    new_spike_segments: np.ndarray
        Segment indices for each spike; only returned when segment_indices is non-empty

    """

    # check lengths are consistent
    assert len(spike_indices) == len(spike_labels), "Length of spike indices and labels are not equal"
    assert (len(segment_indices) == 0) or (
        len(spike_indices) == len(segment_indices)
    ), "Length of spike indices and segments are not equal"
    assert len(added_spikes_indices) == len(
        added_spikes_labels
    ), "Length of added spike indices and labels are not equal"
    assert (len(added_segment_indices) == 0) or (
        len(added_spikes_indices) == len(added_segment_indices)
    ), "Length of added spike indices and segments are not equal"

    new_spike_indices = np.array(spike_indices)
    new_spike_labels = np.array(spike_labels)
    new_spike_segments = np.array(segment_indices)

    rng = np.random.default_rng(seed=seed)

    if replace:
        replacement_indices = rng.choice(len(spike_indices), len(added_spikes_indices), replace=False)
        new_spike_indices[replacement_indices] = added_spikes_indices
        new_spike_labels[replacement_indices] = added_spikes_labels
        if len(segment_indices) != 0:
            new_spike_segments[replacement_indices] = added_segment_indices
    else:
        new_spike_indices = np.concatenate((new_spike_indices, added_spikes_indices))
        new_spike_labels = np.concatenate((new_spike_labels, added_spikes_labels))
        if len(segment_indices) != 0:
            new_spike_segments = np.concatenate((new_spike_segments, added_segment_indices))

    if len(segment_indices) == 0:
        return new_spike_indices, new_spike_labels
    else:
        return new_spike_indices, new_spike_labels, new_spike_segments


def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=None):
    """
    Inject some duplicate units in a sorting.
    [...]
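
A side note on the "See #2345" suggestion above: Python evaluates a default like segment_indices=[] once, at function definition time, so the same list object is shared across every call that relies on the default. Below is a minimal, self-contained sketch of the pitfall; the collect/collect_fixed helpers are hypothetical, not SpikeInterface code:

def collect(value, acc=[]):  # bug: the default list is created once and then reused
    acc.append(value)
    return acc

print(collect(1))  # [1]
print(collect(2))  # [1, 2] - state has leaked from the first call

def collect_fixed(value, acc=None):  # the pattern suggested in the review
    if acc is None:
        acc = []  # a fresh list is created on every call
    acc.append(value)
    return acc

print(collect_fixed(1))  # [1]
print(collect_fixed(2))  # [1]

In _add_spikes_to_spiketrain the default lists are read but never mutated, so the bug does not actually bite here; the suggestion is mainly about project-wide convention, per the #2345 discussion.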
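
For orientation, here is a minimal usage sketch of _add_spikes_to_spiketrain, modeled on the tests in test_generate.py below. It assumes both helpers are importable from spikeinterface.core.generate, as in those tests, and note that commit f6652f0 deletes this helper again before the merge, so this is illustration only:

import numpy as np
from spikeinterface.core.generate import synthesize_random_firings, _add_spikes_to_spiketrain

# baseline spike train: 2 units firing at 2 Hz and 3 Hz for 2 seconds
spike_indices, spike_labels = synthesize_random_firings(num_units=2, duration=2, firing_rates=[2.0, 3.0])

# append two extra spikes: unit 0 at sample 15000 and unit 1 at sample 1
new_indices, new_labels = _add_spikes_to_spiketrain(
    spike_indices,
    spike_labels,
    added_spikes_indices=[15000, 1],
    added_spikes_labels=[0, 1],
    replace=False,  # concatenate rather than overwrite existing spikes
)
assert len(new_indices) == len(spike_indices) + 2

With replace=True the same call would instead overwrite two randomly chosen existing spikes, leaving the train length unchanged.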
File: src/spikeinterface/core/tests/test_generate.py (114 additions, 0 deletions)

@@ -19,6 +19,8 @@
    generate_unit_locations,
    generate_ground_truth_recording,
    generate_sorting_to_inject,
    synthesize_random_firings,
    _add_spikes_to_spiketrain,
)

from spikeinterface.core.numpyextractors import NumpySorting
[...]

@@ -555,6 +557,118 @@ def test_generate_sorting_to_inject():
        assert num_injected_spikes[unit_id] <= num_spikes[unit_id]


def test_synthesize_random_firings_length():

    firing_rates = [2.0, 3.0]
    duration = 2
    num_units = 2

    spike_times, spike_units = synthesize_random_firings(
        num_units=num_units, duration=duration, firing_rates=firing_rates
    )

    assert len(spike_times) == int(np.sum(firing_rates) * duration)

    units, counts = np.unique(spike_units, return_counts=True)

    assert len(units) == num_units
    assert np.sum(counts) == int(np.sum(firing_rates) * duration)


def test_synthesize_random_firings_length_with_insertion():

    firing_rates = [2.0, 3.0]
    duration = 2
    num_units = 2

    added_spikes_indices = [15000, 1]
    added_spikes_labels = [12, 2]

    spike_train_indices, spike_train_labels = synthesize_random_firings(
        num_units=num_units, duration=duration, firing_rates=firing_rates
    )

    spike_times, _ = _add_spikes_to_spiketrain(
        spike_train_indices,
        spike_train_labels,
        added_spikes_indices=added_spikes_indices,
        added_spikes_labels=added_spikes_labels,
        replace=False,
    )

    assert len(spike_times) == int(np.sum(firing_rates) * duration) + 2

    spike_train_indices, spike_train_labels = synthesize_random_firings(
        num_units=num_units, duration=duration, firing_rates=firing_rates
    )

    spike_times, _ = _add_spikes_to_spiketrain(
        spike_train_indices,
        spike_train_labels,
        added_spikes_indices=added_spikes_indices,
        added_spikes_labels=added_spikes_labels,
        replace=True,
    )

    assert len(spike_times) == int(np.sum(firing_rates) * duration)


def test_add_insertions_replacement():

    train_length = 10

    spike_train_indices = np.zeros(train_length)
    spike_train_labels = np.zeros(train_length)
    segment_indices = np.zeros(train_length)

    added_spikes_indices = [15.0, 1.0]
    added_spikes_labels = [12.0, 2.0]
    added_segment_indices = [0.0, 3.0]

    new_spike_indices, new_spike_labels, new_segment_indices = _add_spikes_to_spiketrain(
        spike_train_indices,
        spike_train_labels,
        segment_indices=segment_indices,
        added_spikes_indices=added_spikes_indices,
        added_spikes_labels=added_spikes_labels,
        added_segment_indices=added_segment_indices,
        replace=True,
    )

    assert added_spikes_indices[0] in new_spike_indices
    assert added_spikes_labels[1] in new_spike_labels
    assert added_segment_indices[1] in new_segment_indices
    assert np.shape(new_spike_indices) == np.shape(spike_train_indices)


def test_add_insertions_no_replacement():

    train_length = 10

    spike_train_indices = np.zeros(train_length)
    spike_train_labels = np.zeros(train_length)
    segment_indices = np.zeros(train_length)

    added_spikes_indices = [15.0, 1.0]
    added_spikes_labels = [12.0, 2.0]
    added_segment_indices = [0.0, 3.0]

    new_spike_indices, new_spike_labels, new_segment_indices = _add_spikes_to_spiketrain(
        spike_train_indices,
        spike_train_labels,
        segment_indices=segment_indices,
        added_spikes_indices=added_spikes_indices,
        added_spikes_labels=added_spikes_labels,
        added_segment_indices=added_segment_indices,
        replace=False,
    )

    assert added_spikes_indices[0] in new_spike_indices
    assert added_spikes_labels[1] in new_spike_labels
    assert added_segment_indices[1] in new_segment_indices
    assert len(new_spike_indices) == len(spike_train_indices) + len(added_spikes_indices)


if __name__ == "__main__":
    strategy = "tile_pregenerated"
    # strategy = "on_the_fly"
    [...]
File: src/spikeinterface/qualitymetrics/__init__.py (1 addition, 0 deletions)

@@ -6,3 +6,4 @@
    get_default_qm_params,
)
from .pca_metrics import get_quality_pca_metric_list
from .misc_metrics import get_synchrony_counts
File: src/spikeinterface/qualitymetrics/misc_metrics.py (73 additions, 39 deletions)

@@ -496,7 +496,51 @@ def compute_sliding_rp_violations(
     )


-def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
+def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):

[Review comment from a Member]
I would expect a delta in the signature, no?

"""Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes`

Parameters
----------
spikes : np.array
Structured numpy array with fields ("sample_index", "unit_index", "segment_index").
synchrony_sizes : numpy array

[Review comment from a Collaborator]
Is this given in seconds, samples, milliseconds?

[Author's reply]
I've made the description more accurate (since it's actually a structured numpy array). One of the fields is "sample_index", which clarifies the unit.

[Collaborator]
I was talking about the synchrony_sizes.

[Author's reply]
Oh, sorry. These are the numbers of synchronous events you want to count. So if you want to see when two or four spikes fire at the same time, you use synchrony_sizes = (2, 4). So it's an integer, and I think it will be clear to anyone who knows enough about the metric that they're using it. (A usage sketch at the end of this page makes this concrete.)

[h-mayorquin (Collaborator), Mar 27, 2024]
Thanks. I guess my lack of familiarity with the metric is evident then. I somehow imagined that it was the window in which an event would be counted as synchronous. Thanks for explaining it.

+        The synchrony sizes to compute. Should be pre-sorted.
+    all_unit_ids : list
+        List of unit ids to compute the synchrony metrics. Expecting all units.
+
+    Returns
+    -------
+    synchrony_counts : np.ndarray
+        The synchrony counts for the synchrony sizes, shaped (len(synchrony_sizes), len(all_unit_ids)).
+
+    References
+    ----------
+    Based on concepts described in [Gruen]_
+    This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
+    """
+

+    synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64)

[Review comment from a Member]
In this, the synchrony is highly dependent on the sampling rate, no?

+    # compute the occurrence of each sample_index. Count >= 2 means there's synchrony
+    _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True)
+
+    sync_indices = unique_spike_index[counts >= 2]
+    sync_counts = counts[counts >= 2]
+
+    for i, sync_index in enumerate(sync_indices):
+        num_of_syncs = sync_counts[i]
+        units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)]
+
+        # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added
+        # to the 2 simultaneous spike bins.
+        how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
+        synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1
+
+    return synchrony_counts


+def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
"""Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
"synchrony_size" spikes at the exact same sample index.

[...]

@@ -521,49 +565,39 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
     """
     assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
-    sorting = sorting_analyzer.sorting
-    spikes = sorting.to_spike_vector(concatenated=False)
+    # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
+    synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
+    synchrony_sizes_np.sort()

-    if unit_ids is None:
-        unit_ids = sorting_analyzer.unit_ids
+    res = namedtuple("synchrony", [f"sync_spike_{size}" for size in synchrony_sizes_np])

-    # Pre-allocate synchrony counts
-    synchrony_counts = {}
-    for synchrony_size in synchrony_sizes:
-        synchrony_counts[synchrony_size] = np.zeros(len(sorting_analyzer.unit_ids), dtype=np.int64)
+    sorting = sorting_analyzer.sorting

-    all_unit_ids = list(sorting.unit_ids)
-    for segment_index in range(sorting.get_num_segments()):
-        spikes_in_segment = spikes[segment_index]
+    spike_counts = sorting.count_num_spikes_per_unit(outputs="dict")

-        # we compute just by counting the occurrence of each sample_index
-        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)
+    spikes = sorting.to_spike_vector()
+    all_unit_ids = sorting.unit_ids
+    synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)

+    synchrony_metrics_dict = {}
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+        synch_id_metrics_dict = {}

[Review comment from a Member]
Suggested change: synch_id_metrics_dict -> sync_id_metrics_dict
Can we stick to the same naming? :)
Sorry for being overly annoying @chrishalcrow !!!

[Author's reply]
You can name it anything you want if I can stop working on this ;) Just you wait: my next quality_metrics update is gonna be so careful!


+        for i, unit_id in enumerate(all_unit_ids):
+            if spike_counts[unit_id] != 0:
+                synch_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
+            else:
+                synch_id_metrics_dict[unit_id] = 0
+        synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = synch_id_metrics_dict

-        # add counts for this segment
-        for unit_id in unit_ids:
-            unit_index = all_unit_ids.index(unit_id)
-            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
-            # some segments/units might have no spikes
-            if len(spikes_per_unit) == 0:
-                continue
-            spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
-            for synchrony_size in synchrony_sizes:
-                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)

-    # add counts for this segment
-    synchrony_metrics_dict = {
-        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id]
-            for unit_id in unit_ids
-        }
-        for synchrony_size in synchrony_sizes
-    }

-    # Convert dict to named tuple
-    synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
-    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
-    return synchrony_metrics
+    if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
+        return res(**synchrony_metrics_dict)
+    else:
+        reduced_synchrony_metrics_dict = {}
+        for key in synchrony_metrics_dict:
+            reduced_synchrony_metrics_dict[key] = {
+                unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids
+            }
+        return res(**reduced_synchrony_metrics_dict)


_default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
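
To make the counting behaviour above concrete, here is a small usage sketch of get_synchrony_counts, referenced in the review thread above. The tiny spike vector is fabricated for illustration, it assumes the export added in __init__.py above, and it relies on spikes being sorted by sample_index, which sorting.to_spike_vector() provides:

import numpy as np
from spikeinterface.qualitymetrics import get_synchrony_counts

# three units; units 0, 1 and 2 all fire at sample 20 (a synchrony event of size 3)
spikes = np.array(
    [(10, 0, 0), (20, 0, 0), (20, 1, 0), (20, 2, 0), (30, 1, 0)],
    dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")],
)

sync_counts = get_synchrony_counts(spikes, np.array([2, 4]), all_unit_ids=[0, 1, 2])

# counting is inclusive: the size-3 event increments the size-2 bin for each
# participating unit, but does not reach the size-4 bin
assert np.array_equal(sync_counts, np.array([[1, 1, 1], [0, 0, 0]]))

compute_synchrony_metrics then divides these counts by each unit's total spike count and returns the result as a namedtuple: for example, its sync_spike_2 field maps each unit_id to the fraction of that unit's spikes that take part in a synchrony event of size 2 or more.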