Commit f33f80c

Look into the template option.
1 parent b5c85ff commit f33f80c

1 file changed: +26 −42

src/spikeinterface/working/load_kilosort_utils.py (26 additions, 42 deletions)
@@ -30,7 +30,7 @@ def compute_spike_amplitude_and_depth(
     """
     Compute the indicies, amplitudes and locations for all detected spikes from the kilosort output.
 
-    This function is based on code in Nick Steinmetz's `spikes` repository,
+    This function is based on code in Cortex Lab's `spikes` repository,
     https://github.com/cortex-lab/spikes
 
     Parameters
@@ -119,54 +119,27 @@ def _get_locations_from_pc_features(params):
 
     Notes
     -----
-    Location of of each individual spike is computed from its low-dimensional projection.
-    During sorting, kilosort computes the '
-    `pc_features` (num_spikes, num_PC, num_channels) holds the PC values for each spike.
-    Taking the first component, the subset of 32 channels associated with this
-    spike are indexed to get the actual channel locations (in um). Then, the channel
-    locations are weighted by their PC values.
-
-    This function is based on code in Nick Steinmetz's `spikes` repository,
+    My understanding so far. KS1 paper; The individual spike waveforms are decomposed into
+    'private PCs'. Let the waveform matrix W be time (t) x channel (c). PCA
+    decompoisition is performed to compute c basis waveforms. Scores for each
+    channel onto the top three PCs are stored (these recover the waveform well.
+
+    This function is based on code in Cortex Lab's `spikes` repository,
     https://github.com/cortex-lab/spikes
     """
-    # Compute spike depths
-
-    # for each spike, a PCA is computed just on that spike (n samples x n channels).
-    # the components are all different between spikes, so are not saved.
-    # This gives a (n pc = 3, num channels) set of scores.
-    # but then how it is possible for some spikes to have zero score onto the principal channel?
-
-    breakpoint()
-    pc_features = params["pc_features"][:, 0, :]
+    pc_features = params["pc_features"][:, 0, :].copy()
     pc_features[pc_features < 0] = 0
 
-    # Some spikes do not load at all onto the first PC. To avoid biasing the
-    # dataset by removing these, we repeat the above for the next PC,
-    # to compute distances for neurons that do not load onto the 1st PC.
-    # This is not ideal at all, it would be much better to a) find the
-    # max value for each channel on each of the PCs (i.e. basis vectors).
-    # Then recompute the estimated waveform peak on each channel by
-    # summing the PCs by their respective weights. However, the PC basis
-    # vectors themselves do not appear to be output by KS.
-
-    # We include the (n_channels i.e. features) from the second PC
-    # into the `pc_features` mostly containing the first PC. As all
-    # operations are per-spike (i.e. row-wise)
-    no_pc1_signal_spikes = np.where(np.sum(pc_features, axis=1) == 0)
-
-    pc_features_2 = params["pc_features"][:, 1, :]
-    pc_features_2[pc_features_2 < 0] = 0
-
-    pc_features[no_pc1_signal_spikes] = pc_features_2[no_pc1_signal_spikes]
-
     if np.any(np.sum(pc_features, axis=1) == 0):
+        # TODO: 1) handle this case for pc_features
+        # 2) instead use the template_features for all other versions.
         raise RuntimeError(
             "Some spikes do not load at all onto the first"
             "or second principal component. It is necessary"
             "to extend this code section to handle more components."
         )
 
-    # Get the channel indices corresponding to the 32 channels from the PC.
+    # Get the channel indices corresponding to the channels from the PC.
     spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]
 
     # Compute the spike locations as the center of mass of the PC scores
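For orientation, a rough sketch of the center-of-mass location estimate that the comment above refers to, built only from the arrays named in this hunk. The channel_positions argument and the squared-score weighting are assumptions for illustration, not necessarily what this module does.

import numpy as np

def estimate_spike_depths(params, channel_positions):
    # Rectified scores of each spike onto the first PC, as in the diff above.
    pc_features = params["pc_features"][:, 0, :].copy()
    pc_features[pc_features < 0] = 0

    # Channels that each spike's template PC features refer to.
    spike_features_indices = params["pc_features_indices"][params["spike_templates"], :]

    # Depth (y coordinate, um) of every contributing channel, per spike.
    spike_feature_ycoords = channel_positions[spike_features_indices, 1]

    # Center of mass of the (squared) PC scores along the probe depth axis.
    weights = pc_features**2
    return np.sum(spike_feature_ycoords * weights, axis=1) / np.sum(weights, axis=1)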
@@ -199,7 +172,7 @@ def get_unwhite_template_info(
     Amplitude is calculated for each spike as the template amplitude
     multiplied by the `template_scaling_amplitudes`.
 
-    This function is based on code in Nick Steinmetz's `spikes` repository,
+    This function is based on code in Cortex Lab's `spikes` repository,
     https://github.com/cortex-lab/spikes
 
     Parameters
@@ -277,7 +250,7 @@ def compute_template_amplitudes_from_spikes(templates, spike_templates, spike_am
     Take the average of all spike amplitudes to get actual template amplitudes
     (since tempScalingAmps are equal mean for all templates)
 
-    This function is ported from Nick Steinmetz's `spikes` repository,
+    This function is ported from Cortex Lab's `spikes` repository,
     https://github.com/cortex-lab/spikes
     """
     num_indices = templates.shape[0]
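The docstring in this hunk describes averaging the per-spike amplitudes of each template. A minimal, hypothetical sketch of that kind of per-template averaging (function and argument names are assumed, not taken from the file):

import numpy as np

def average_amplitude_per_template(spike_templates, spike_amplitudes, num_templates):
    # Sum the amplitudes of the spikes assigned to each template, then divide
    # by the spike count; templates with no spikes get an amplitude of 0.
    counts = np.bincount(spike_templates, minlength=num_templates)
    sums = np.bincount(spike_templates, weights=spike_amplitudes, minlength=num_templates)
    return sums / np.maximum(counts, 1)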
@@ -297,7 +270,7 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
     """
    Loads the output of Kilosort into a `params` dict.
 
-    This function was ported from Nick Steinmetz's `spikes` repository MATLAB
+    This function was ported from Cortex Lab's `spikes` repository MATLAB
     code, https://github.com/cortex-lab/spikes
 
     Parameters
@@ -343,8 +316,15 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
     if load_pcs:
         pc_features = np.load(sorter_output / "pc_features.npy")
         pc_features_indices = np.load(sorter_output / "pc_feature_ind.npy")
+
+        if (sorter_output / "template_features.npy").is_file():
+            template_features = np.load(sorter_output / "template_features.npy")
+            template_features_indices = np.load(sorter_output / "templates_ind.npy")
+        else:
+            template_features = template_features_indices = None
     else:
         pc_features = pc_features_indices = None
+        template_features = template_features_indices = None
 
     # This makes the assumption that there will never be different .csv and .tsv files
     # in the same sorter output (this should never happen, there will never even be two).
@@ -364,6 +344,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
 
     if load_pcs:
         pc_features = pc_features[not_noise_clusters_by_spike, :, :]
+        if template_features is not None:
+            template_features = template_features[not_noise_clusters_by_spike, :, :]
 
     spike_clusters = spike_clusters[not_noise_clusters_by_spike]
     cluster_ids = cluster_ids[cluster_groups != 0]
@@ -378,6 +360,8 @@ def load_ks_dir(sorter_output: Path, exclude_noise: bool = True, load_pcs: bool
         "spike_clusters": spike_clusters.squeeze(),
         "pc_features": pc_features,
         "pc_features_indices": pc_features_indices,
+        "template_features": template_features,
+        "template_features_indices": template_features_indices,
         "temp_scaling_amplitudes": temp_scaling_amplitudes.squeeze(),
         "cluster_ids": cluster_ids,
         "cluster_groups": cluster_groups,
@@ -399,7 +383,7 @@ def _load_cluster_groups(cluster_path: Path) -> tuple[np.ndarray, ...]:
     There is some slight formatting differences between the `.tsv` and `.csv`
     versions, presumably from different kilosort versions.
 
-    This function was ported from Nick Steinmetz's `spikes` repository MATLAB code,
+    This function was ported from Cortex Lab's `spikes` repository MATLAB code,
     https://github.com/cortex-lab/spikes
 
     Parameters
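For context, a hedged usage sketch of load_ks_dir after this commit. The sorter-output path is a placeholder and the import path is an assumption; only the keys shown as added in the diff are relied on.

from pathlib import Path

from load_kilosort_utils import load_ks_dir  # import path is an assumption

params = load_ks_dir(Path("/path/to/kilosort_output"), exclude_noise=True, load_pcs=True)

# Keys added by this commit; both are None when template_features.npy is absent
# (or when load_pcs=False).
template_features = params["template_features"]
template_features_indices = params["template_features_indices"]
if template_features is not None:
    print(template_features.shape, template_features_indices.shape)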
