-
Notifications
You must be signed in to change notification settings - Fork 231
Implementing sparsity for ComputeTemplates #4212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Kudo for the code rewriting... All tests are failing :-) |
…sparsity_templates
for more information, see https://pre-commit.ci
|
@alejoe91 @chrishalcrow : this will help a lot ith very high channel count (in vitro mea). |
|
@alejoe91 we need to squash |
| ) | ||
| for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): | ||
| chan_inds = self.sparsity.unit_id_to_channel_indices[unit_id] | ||
| dense_arr[unit_index][:, chan_inds] = arr[unit_index, :, : chan_inds.size] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have the exact same logic in ChannelSparsity.densify_waveforms:
My proposal is to add a densify_templates function to the ChannelSparsity class (it already has a sparsify_templates):
def densify_templates(self, templates_array: np.ndarray) -> np.ndarray:
assert templates_array.shape[0] == self.num_units
densified_shape = (self.num_units, templates_array.shape[1], self.num_channels)
dense_templates = np.zeros(shape=densified_shape, dtype=templates_array.dtype)
for unit_index, unit_id in enumerate(self.unit_ids):
sparse_template = templates_array[unit_index, ...]
dense_template = self.densify_waveforms(waveforms=sparse_template[np.newaxis, :, :], unit_id=unit_id)
dense_templates[unit_index, :, :] = dense_template
return dense_templates
Then the logic in the ComputeTemplates extension could simply become:
if self.sparsity is not None:
# make average and std dense again
for k, arr in data.items():
dense_arr = self.sparsity.densify_templates(arr)
data[k] = dense_arr
What do you think?
alejoe91
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should use the ChannelSparsity methods
This implements the use of sparsity to compute templates in the Analyzer. Currently, analzers can handle sparsity mask, and while they are used to compute waveforms, they are not used to estimate templates. This can lead to large memory usage with numerous threads. This PR makes sure the analyzer, if sparse, uses the sparsity masks while estimating templates ONLY in the run() method. For compatibility with downstream extensions, the internal storage is kept dense*
See #4194