compute_synchrony_metrics update #2605
Changes from 22 commits
@@ -845,6 +845,85 @@ def clean_refractory_period(times, refractory_period):
     return times


+def _add_spikes_to_spiketrain(
+    spike_indices,
+    spike_labels,
+    segment_indices=[],
+    added_spikes_indices=None,
+    added_spikes_labels=None,
+    added_segment_indices=[],
+    replace=False,
+    seed=None,
+):
+    """
+    Add specified spikes into a spike train.
+
+    Parameters
+    ----------
+    spike_indices : array-like
+        List of sample indices for each spike
+    spike_labels : array-like
+        List of unit labels for each spike
+    segment_indices : array-like
+        List of segment indices for each spike
+    added_spikes_indices : array-like
+        List of sample indices for the added spikes
+    added_spikes_labels : array-like
+        List of unit labels for the added spikes
+    added_segment_indices : array-like
+        List of segment indices for the added spikes
+    replace : bool, default: False
+        If True, randomly replace existing spikes with the added spikes. If False, add to the existing spike train
+    seed : int, default: None
+        Seed for the random number generator
+
+    Returns
+    -------
+    new_spike_indices, new_spike_labels : np.ndarray
+        Arrays in the same form as the inputs, with the added spikes included.
+        If segment_indices is non-empty, new_spike_segments is returned as well
+
+    """
+
+    # check lengths are consistent
+    assert len(spike_indices) == len(spike_labels), "Length of spike indices and labels are not equal"
+    assert (len(segment_indices) == 0) or (
+        len(spike_indices) == len(segment_indices)
+    ), "Length of spike indices and segments are not equal"
+    assert len(added_spikes_indices) == len(
+        added_spikes_labels
+    ), "Length of added spike indices and labels are not equal"
+    assert (len(added_segment_indices) == 0) or (
+        len(added_spikes_indices) == len(added_segment_indices)
+    ), "Length of added spike indices and segments are not equal"
+
+    new_spike_indices = np.array(spike_indices)
+    new_spike_labels = np.array(spike_labels)
+    new_spike_segments = np.array(segment_indices)
+
+    rng = np.random.default_rng(seed=seed)
+
+    if replace:
+        replacement_indices = rng.choice(len(spike_indices), len(added_spikes_indices), replace=False)
+        new_spike_indices[replacement_indices] = added_spikes_indices
+        new_spike_labels[replacement_indices] = added_spikes_labels
+        if len(segment_indices) != 0:
+            new_spike_segments[replacement_indices] = added_segment_indices
+    else:
+        new_spike_indices = np.concatenate((new_spike_indices, added_spikes_indices))
+        new_spike_labels = np.concatenate((new_spike_labels, added_spikes_labels))
+        if len(segment_indices) != 0:
+            new_spike_segments = np.concatenate((new_spike_segments, added_segment_indices))
+
+    if len(segment_indices) == 0:
+        return new_spike_indices, new_spike_labels
+    else:
+        return new_spike_indices, new_spike_labels, new_spike_segments
+
+
 def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=None):
     """
     Inject some duplicate units in a sorting.
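To make the append/replace behavior of the new helper concrete, here is a minimal usage sketch (the data are hypothetical; it assumes numpy is imported as np and _add_spikes_to_spiketrain, as defined above, is in scope):

    import numpy as np

    # a toy spike train: five spikes from two units, single segment
    spike_indices = np.array([100, 250, 400, 460, 900])
    spike_labels = np.array([0, 1, 0, 1, 0])

    # append two extra spikes from unit 1 (replace=False, the default)
    new_indices, new_labels = _add_spikes_to_spiketrain(
        spike_indices,
        spike_labels,
        added_spikes_indices=np.array([400, 700]),
        added_spikes_labels=np.array([1, 1]),
        seed=0,
    )
    # new_indices now has 7 entries; sample 400 hosts spikes from units 0 and 1,
    # which is exactly the kind of synchronous event the tests below need to generate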
@@ -496,7 +496,51 @@ def compute_sliding_rp_violations(
     )


-def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None, **kwargs):
+def get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids):
Reviewer: I would expect a delta in the signature, no?
"""Compute synchrony counts, the number of simultaneous spikes with sizes `synchrony_sizes` | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
spikes : np.array | ||||||
Structured numpy array with fields ("sample_index", "unit_index", "segment_index"). | ||||||
synchrony_sizes : numpy array | ||||||
Reviewer: Is this given in seconds, samples, or milliseconds?
Author: I've made the description more accurate (since it's actually a structured numpy array). One of the fields is "sample_index", which clarifies the unit.
Reviewer: I was talking about the synchrony_sizes.
Author: Oh, sorry. These are the numbers of synchronous events you want to count. So if you want to see when two or four spikes fire at the same time, you use synchrony_sizes = (2, 4). So it's an integer, and I think it will be clear to anyone who knows enough about the metric that they're using it.
Reviewer: Thanks. I guess my lack of familiarity with the metric is evident then. I somehow imagined that it was the window in which an event would be counted as synchronous. Thanks for explaining it.
+        The synchrony sizes to compute. Should be pre-sorted.
+    all_unit_ids : list
+        List of unit ids to compute the synchrony metrics for. Expecting all units.
+
+    Returns
+    -------
+    synchrony_counts : np.ndarray
+        The synchrony counts for the synchrony sizes.
+
+    References
+    ----------
+    Based on concepts described in [Gruen]_
+    This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
+    """
+
+    synchrony_counts = np.zeros((np.size(synchrony_sizes), len(all_unit_ids)), dtype=np.int64)
Reviewer: In this, the synchrony is highly dependent on the sampling rate, no?
+
+    # compute the occurrence of each sample_index. Count >= 2 means there's synchrony
+    _, unique_spike_index, counts = np.unique(spikes["sample_index"], return_index=True, return_counts=True)
+
+    sync_indices = unique_spike_index[counts >= 2]
+    sync_counts = counts[counts >= 2]
+
+    for i, sync_index in enumerate(sync_indices):
+
+        num_of_syncs = sync_counts[i]
+        units_with_sync = [spikes[sync_index + a][1] for a in range(0, num_of_syncs)]
+
+        # Counts inclusively. E.g. if there are 3 simultaneous spikes, these are also added
+        # to the 2 simultaneous spike bins.
+        how_many_bins_to_add_to = np.size(synchrony_sizes[synchrony_sizes <= num_of_syncs])
+        synchrony_counts[:how_many_bins_to_add_to, units_with_sync] += 1
+
+    return synchrony_counts
+
+
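For concreteness, here is how this counting rule plays out on a toy spike vector (an illustrative sketch with hypothetical data; it assumes numpy as np and get_synchrony_counts as defined above):

    import numpy as np

    # three units; samples 10 and 20 each host simultaneous spikes
    spikes = np.array(
        [(10, 0, 0), (10, 1, 0), (10, 2, 0), (20, 0, 0), (20, 1, 0), (30, 2, 0)],
        dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")],
    )
    synchrony_sizes = np.array([2, 3])

    counts = get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids=[0, 1, 2])
    # sample 10 is a size-3 event, counted inclusively towards sizes 2 and 3;
    # sample 20 is a size-2 event, counted towards size 2 only:
    # counts == [[2, 2, 1],    <- sync size 2
    #            [1, 1, 1]]    <- sync size 3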
+def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_ids=None):
+    """Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of
+    "synchrony_size" spikes at the exact same sample index.
+
@@ -521,49 +565,39 @@ def compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8), unit_
     This code was adapted from `Elephant - Electrophysiology Analysis Toolkit <https://github.com/NeuralEnsemble/elephant/blob/master/elephant/spike_train_synchrony.py#L245>`_
     """
     assert min(synchrony_sizes) > 1, "Synchrony sizes must be greater than 1"
-    spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="dict")
-    sorting = sorting_analyzer.sorting
-    spikes = sorting.to_spike_vector(concatenated=False)
+    # Sort the synchrony times so we can slice numpy arrays, instead of using dicts
+    synchrony_sizes_np = np.array(synchrony_sizes, dtype=np.int16)
+    synchrony_sizes_np.sort()

-    if unit_ids is None:
-        unit_ids = sorting_analyzer.unit_ids
+    res = namedtuple("synchrony", [f"sync_spike_{size}" for size in synchrony_sizes_np])

-    # Pre-allocate synchrony counts
-    synchrony_counts = {}
-    for synchrony_size in synchrony_sizes:
-        synchrony_counts[synchrony_size] = np.zeros(len(sorting_analyzer.unit_ids), dtype=np.int64)
+    sorting = sorting_analyzer.sorting

-    all_unit_ids = list(sorting.unit_ids)
-    for segment_index in range(sorting.get_num_segments()):
-        spikes_in_segment = spikes[segment_index]
+    spike_counts = sorting.count_num_spikes_per_unit(outputs="dict")

-        # we compute just by counting the occurrence of each sample_index
-        unique_spike_index, complexity = np.unique(spikes_in_segment["sample_index"], return_counts=True)
+    spikes = sorting.to_spike_vector()
+    all_unit_ids = sorting.unit_ids
+    synchrony_counts = get_synchrony_counts(spikes, synchrony_sizes_np, all_unit_ids)
+
+    synchrony_metrics_dict = {}
+    for sync_idx, synchrony_size in enumerate(synchrony_sizes_np):
+        synch_id_metrics_dict = {}
Reviewer: Suggested change: can we stick to the same naming? :) Sorry for being overly annoying @chrishalcrow !!!
Author: You can name it anything you want if I can stop working on this ;) Just you wait: my next quality_metrics update is gonna be so careful!
+        for i, unit_id in enumerate(all_unit_ids):
+            if spike_counts[unit_id] != 0:
+                synch_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id]
+            else:
+                synch_id_metrics_dict[unit_id] = 0
+        synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = synch_id_metrics_dict

-        # add counts for this segment
-        for unit_id in unit_ids:
-            unit_index = all_unit_ids.index(unit_id)
-            spikes_per_unit = spikes_in_segment[spikes_in_segment["unit_index"] == unit_index]
-            # some segments/units might have no spikes
-            if len(spikes_per_unit) == 0:
-                continue
-            spike_complexity = complexity[np.isin(unique_spike_index, spikes_per_unit["sample_index"])]
-            for synchrony_size in synchrony_sizes:
-                synchrony_counts[synchrony_size][unit_index] += np.count_nonzero(spike_complexity >= synchrony_size)

-    # add counts for this segment
-    synchrony_metrics_dict = {
-        f"sync_spike_{synchrony_size}": {
-            unit_id: synchrony_counts[synchrony_size][all_unit_ids.index(unit_id)] / spike_counts[unit_id]
-            for unit_id in unit_ids
-        }
-        for synchrony_size in synchrony_sizes
-    }

-    # Convert dict to named tuple
-    synchrony_metrics_tuple = namedtuple("synchrony_metrics", synchrony_metrics_dict.keys())
-    synchrony_metrics = synchrony_metrics_tuple(**synchrony_metrics_dict)
-    return synchrony_metrics
+
+    if np.all(unit_ids == None) or (len(unit_ids) == len(all_unit_ids)):
+        return res(**synchrony_metrics_dict)
+    else:
+        reduced_synchrony_metrics_dict = {}
+        for key in synchrony_metrics_dict:
+            reduced_synchrony_metrics_dict[key] = {
+                unit_id: synchrony_metrics_dict[key][unit_id] for unit_id in unit_ids
+            }
+        return res(**reduced_synchrony_metrics_dict)


 _default_params["synchrony"] = dict(synchrony_sizes=(2, 4, 8))
Comment: See #2345
Reply: To quote Joe: "Thanks I did not know about that, I am learning so much working on the SI codebase 😍!"
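Finally, a sketch of how the updated metric might be called end to end (the toy recording is hypothetical; generate_ground_truth_recording, create_sorting_analyzer, and the spikeinterface.full alias are names from the SpikeInterface API at the time of this PR, so treat the exact entry points as assumptions):

    import spikeinterface.full as si

    # build a small synthetic recording/sorting pair and an analyzer
    recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
    sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)

    # rate of size-2/4/8 coincidences per unit, normalized by each unit's spike count
    synchrony = si.compute_synchrony_metrics(sorting_analyzer, synchrony_sizes=(2, 4, 8))
    print(synchrony.sync_spike_2)  # {unit_id: fraction of that unit's spikes in >=2-spike events}

The returned object is the namedtuple built by the new code, with one field per requested synchrony size (sync_spike_2, sync_spike_4, sync_spike_8), each holding a per-unit dict of counts divided by that unit's total spike count.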