Refactor metrics into its own module #4183
base: main
Conversation
```python
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids_f, :] = self._compute_metrics(
```
Hello, this is a new thing. It'd be great to check whether we can compute a metric before we try to, for the following situation:

Suppose you originally compute a metric using spikeinterface version 103 (or some fork that you've made yourself... ahem). Then you open your analyzer in si-gui using version 102. There was a new metric introduced in 103, which 102 doesn't know about. When you try to merge, it errors because it can't compute the new metric. So you can't do any merging at all, due to the inability to compute one metric.

Or you no longer have the recording when you open, so you can't compute sd_ratio or something...

Instead, I'd like to warn if we can't compute a metric and stick in a NaN. We could do that here by checking that `metric_names` are in `self.metric_list`.
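A minimal sketch of the proposed behavior (the names `known_metric_names` and `stored_metric_names` are illustrative stand-ins, not the real spikeinterface attributes): warn about metrics this version can't recompute and fill them with NaN instead of erroring.

```python
import warnings

import numpy as np
import pandas as pd

# Illustrative stand-ins, not the real API.
known_metric_names = {"snr", "firing_rate"}                        # what this version can compute
stored_metric_names = ["snr", "firing_rate", "brand_new_metric"]   # saved by a newer version

computable = [m for m in stored_metric_names if m in known_metric_names]
missing = [m for m in stored_metric_names if m not in known_metric_names]
if missing:
    warnings.warn(f"Cannot recompute {missing}; filling with NaN.")

# Fill the merged unit's row: real values where possible, NaN elsewhere.
metrics = pd.DataFrame(index=["12"], columns=stored_metric_names)
metrics.loc["12", computable] = [5.1, 2.3]
metrics.loc["12", missing] = np.nan
```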
Oh, I think I meant to write this at line 1207 about the merging step, but it also applies to splits!
Nag nag nag
```python
if unit_ids is None:
    unit_ids = sorting_analyzer.unit_ids

_has_required_extensions(sorting_analyzer, metric_name="snr")
```
These checks were done in case people use e.g. compute_snrs directly, and it still gives them a good error message.
```python
    ComputeQualityMetrics,
    compute_quality_metrics,
)
```
We used to import all the compute functions, like compute_snrs.
I think that users can still access the individual functions with from spikeinterface.metrics.quality.misc_metrics import ...
We can make sure we highlight this in the release notes, but I would avoid clogging the import with all the low-level functions
I'd be more careful, because our current docs (e.g. https://spikeinterface.readthedocs.io/en/stable/modules/qualitymetrics/firing_range.html) have lots of code which imports the functions directly from spikeinterface.quality_metrics. So I reckon this could break a lot of code.
This looks great - love the [...]. I think this is a good chance to remove [...]. I'd vote to take the chance to make multi-channel template metrics included by default: they're very helpful.
I agree! Maybe we can make it the default when the number of channels is > 64?
Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
Was having a play - this PR makes it very easy to write your own metrics outside of spikeinterface, and stuff like merging works fine. Not portable, but useful for me!

```python
import spikeinterface.full as si
from spikeinterface.core.analyzer_extension_core import BaseMetric
from spikeinterface.metrics import ComputeQualityMetrics


def compute_something(
    sorting_analyzer,
    unit_ids=None,
):
    if unit_ids is None:
        unit_ids = sorting_analyzer.unit_ids
    result = {unit_id: str(unit_id) + "_hi!!!!" for unit_id in unit_ids}
    return result


class MyCustomMetric(BaseMetric):
    metric_name = "chris_metric"
    metric_function = compute_something
    metric_params = {}
    metric_columns = {"the_unit_id": str}


ComputeQualityMetrics.metric_list.append(MyCustomMetric)

rec, sort = si.generate_ground_truth_recording()
sa = si.create_sorting_analyzer(sort, rec)
sa.compute("quality_metrics")
merged_sa = sa.merge_units(merge_unit_groups=[["1", "2"]], new_unit_ids=["12"])
merged_sa.get_extension("quality_metrics").get_data().loc["12"]
```
```diff
@@ -0,0 +1,19 @@
+Metrics
```
Suggested change:

```diff
-Metrics
+Metrics module
+--------------
```
```python
    get_default_qm_params,
import warnings

warnings.warn(
```
I don't get the deprecation warning if I do e.g.

```python
from spikeinterface.quality_metrics import compute_quality_metrics
```

because of some `import *` magic. This should fix it in almost all cases:

```python
if __name__ not in ('__main__', 'builtins'):
    warnings.warn(
        "The module 'spikeinterface.qualitymetrics' is deprecated and will be removed in 0.105.0. "
        "Please use 'spikeinterface.metrics.quality' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
```
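A quick way to check that this guard behaves as intended (the shim below is a toy, not the real spikeinterface module): executing the same body under a module-like `__name__` fires the warning, while `'__main__'` stays silent.

```python
import warnings

GUARDED_BODY = """
import warnings
if __name__ not in ('__main__', 'builtins'):
    warnings.warn("deprecated module", DeprecationWarning, stacklevel=2)
"""


def run_as(module_name):
    # Execute the guarded body as if the module were imported under module_name.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        exec(compile(GUARDED_BODY, "<shim>", "exec"), {"__name__": module_name})
    return caught


as_import = run_as("some_shim_module")  # warning fires
as_script = run_as("__main__")          # guard suppresses it
```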
```python
tmp_data : dict
    Temporary data to pass to the metric function
job_kwargs : dict
    Job keyword arguments to control paralleization
```
Suggested change:

```diff
-    Job keyword arguments to control paralleization
+    Job keyword arguments to control parallelization
```
```python
metrics : pd.DataFrame
    DataFrame containing the computed metrics for each unit.
"""
import pandas as pd
```
Just noting that we've got pandas in core here. You can't load a sorting analyzer with any metrics extensions without pandas. Time to add it to core deps?
| """ | ||
| return get_default_analyzer_extension_params(extension_name) | ||
|
|
||
| def get_metrics_extension_data(self): |
love love love
```python
sampling_frequency_up = sampling_frequency
tmp_data["sampling_frequency"] = sampling_frequency_up

include_multi_channel_metrics = self.params["include_multi_channel_metrics"] or any(
```
We didn't use to save `"include_multi_channel_metrics"`, so `self.params["include_multi_channel_metrics"]` errors for old analyzers. Replace with `.get` and it's fine:
Suggested change:

```diff
-include_multi_channel_metrics = self.params["include_multi_channel_metrics"] or any(
+include_multi_channel_metrics = self.params.get("include_multi_channel_metrics") or any(
```
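The difference is easy to demonstrate on a plain dict standing in for `self.params` as saved by an older version:

```python
# Params saved by an old analyzer, before the new key existed.
old_params = {"metric_names": ["snr", "firing_rate"]}

# Direct indexing raises KeyError on old analyzers...
try:
    old_params["include_multi_channel_metrics"]
    key_error = False
except KeyError:
    key_error = True

# ...while .get() falls back to None, so `or` can still supply the default logic.
include_multi = old_params.get("include_multi_channel_metrics") or False
```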
```python
extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index")
all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True)
channel_locations = sorting_analyzer.recording.get_channel_locations()
```
Doesn't work for recordingless analyzers. Do we store the channel_locations anywhere else?
Oh yeah, just do sorting_analyzer.get_channel_locations()
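A sketch of why the analyzer-level accessor helps (toy classes, not the real spikeinterface objects): the analyzer can serve stored locations even when the recording is gone.

```python
import numpy as np


class ToyRecording:
    def __init__(self, locations):
        self._locations = np.asarray(locations)

    def get_channel_locations(self):
        return self._locations


class ToyAnalyzer:
    def __init__(self, recording=None, stored_locations=None):
        self.recording = recording
        # e.g. loaded from the probe saved alongside the analyzer
        self._stored_locations = np.asarray(stored_locations) if stored_locations is not None else None

    def get_channel_locations(self):
        if self.recording is not None:
            return self.recording.get_channel_locations()
        if self._stored_locations is not None:
            return self._stored_locations
        raise ValueError("No recording and no stored channel locations.")


with_rec = ToyAnalyzer(recording=ToyRecording([[0.0, 0.0]]))
recordingless = ToyAnalyzer(stored_locations=[[0.0, 0.0], [0.0, 20.0]])
```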
```python
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(
```
This will error if we don't know how to compute the metrics in `metric_names`. So if a metric changes name between versions, we get an error and can't merge/split. I think we should only give `_compute_metrics` the intersection of `metric_names` and `self.metric_list`?
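Restricting the recompute to known metrics could look roughly like this (the metric classes are simplified stand-ins for entries of `self.metric_list`):

```python
class SnrMetric:
    metric_name = "snr"


class FiringRateMetric:
    metric_name = "firing_rate"


metric_list = [SnrMetric, FiringRateMetric]               # what this version ships
metric_names = ["snr", "firing_rate", "renamed_metric"]   # what the old analyzer stored

known = {m.metric_name for m in metric_list}
# Only hand the intersection to _compute_metrics; anything unknown is
# skipped (or NaN-filled) instead of raising during a merge/split.
computable_names = [name for name in metric_names if name in known]
```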
This PR includes a major refactor of the metrics concept.

It defines a `BaseMetric`, with core metadata of individual metrics including dtypes, column names, extension dependence, and a compute function. Another `BaseMetricExtension` contains a collection of `BaseMetric`s and deals with most of the machinery.

The `template_metrics`, `quality_metrics`, and a new `spiketrain_metrics` extensions are now in the `metrics` module. The latter only includes `num_spikes` and `firing_rate`, which are also imported as quality metrics.

Still finalizing tests, but this should be 90% done.