Skip to content

Bug with computing the confusion matrix when using 'best' match mode #2793

Open
@HughYau

Description

@HughYau

When I compare against ground-truth data, an error often occurs when using the 'best' match mode, but when I switch to 'hungarian' mode, it works. It looks like a problem when creating the pandas DataFrame: it should produce a (3+1)×(3+1) (i.e. 4×4) table, but something is wrong with its indices, so pandas expects (4, 5) values.


ValueError Traceback (most recent call last)
Cell In[19], line 9
7 si.plot_agreement_matrix(comp, ordered=True,backend='matplotlib',figtitle = sorter_name+' Agreement matrix')
8 plt.savefig(output_folder/f'agreement_{sorter_name}.png')
----> 9 si.plot_confusion_matrix(comp,backend='matplotlib',figtitle = sorter_name+' Confusion matrix')
10 plt.savefig(output_folder/f'confusion_{sorter_name}.png')
11 perf = comp.get_performance()

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\widgets\comparison.py:29, in ConfusionMatrixWidget.__init__(self, gt_comparison, count_text, unit_ticks, backend, **backend_kwargs)
23 def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs):
24 plot_data = dict(
25 gt_comparison=gt_comparison,
26 count_text=count_text,
27 unit_ticks=unit_ticks,
28 )
---> 29 BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\widgets\base.py:82, in BaseWidget.__init__(self, data_plot, backend, immediate_plot, **backend_kwargs)
79 self.backend_kwargs = backend_kwargs_
81 if immediate_plot:
---> 82 self.do_plot()

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\widgets\base.py:103, in BaseWidget.do_plot(self)
101 def do_plot(self):
102 func = getattr(self, f"plot_{self.backend}")
--> 103 func(self.data_plot, **self.backend_kwargs)

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\widgets\comparison.py:41, in ConfusionMatrixWidget.plot_matplotlib(self, data_plot, **backend_kwargs)
37 self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
39 comp = dp.gt_comparison
---> 41 confusion_matrix = comp.get_confusion_matrix()
42 N1 = confusion_matrix.shape[0] - 1
43 N2 = confusion_matrix.shape[1] - 1

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\comparison\paircomparisons.py:373, in GroundTruthComparison.get_confusion_matrix(self)
364 """
365 Computes the confusion matrix.
366
(...)
370 The confusion matrix
371 """
372 if self._confusion_matrix is None:
--> 373 self._do_confusion_matrix()
374 return self._confusion_matrix

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\comparison\paircomparisons.py:359, in GroundTruthComparison._do_confusion_matrix(self)
356 elif self.match_mode == "best":
357 match_12 = self.best_match_12
--> 359 self._confusion_matrix = do_confusion_matrix(
360 self.event_counts1, self.event_counts2, match_12, self.match_event_count
361 )

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\spikeinterface\comparison\comparisontools.py:720, in do_confusion_matrix(event_counts1, event_counts2, match_12, match_event_count)
716 ordered_units2 = np.hstack([matched_units2, unmatched_units2])
718 import pandas as pd
--> 720 conf_matrix = pd.DataFrame(
721 np.zeros((N1 + 1, N2 + 1), dtype=int),
722 index=list(ordered_units1) + ["FP"],
723 columns=list(ordered_units2) + ["FN"],
724 )
726 for u1 in matched_units1:
727 u2 = match_12[u1]

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\pandas\core\frame.py:816, in DataFrame.__init__(self, data, index, columns, dtype, copy)
805 mgr = dict_to_mgr(
806 # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no
807 # attribute "name"
(...)
813 copy=_copy,
814 )
815 else:
--> 816 mgr = ndarray_to_mgr(
817 data,
818 index,
819 columns,
820 dtype=dtype,
821 copy=copy,
822 typ=manager,
823 )
825 # For data is list-like, or Iterable (will consume into list)
826 elif is_list_like(data):

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\pandas\core\internals\construction.py:336, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
331 # _prep_ndarraylike ensures that values.ndim == 2 at this point
332 index, columns = _get_axes(
333 values.shape[0], values.shape[1], index=index, columns=columns
334 )
--> 336 _check_values_indices_shape_match(values, index, columns)
338 if typ == "array":
339 if issubclass(values.dtype.type, str):

File d:\Applications\anaconda3\envs\eeg\lib\site-packages\pandas\core\internals\construction.py:420, in _check_values_indices_shape_match(values, index, columns)
418 passed = values.shape
419 implied = (len(index), len(columns))
--> 420 raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")

ValueError: Shape of passed values is (4, 4), indices imply (4, 5)

Metadata

Metadata

Assignees

No one assigned

    Labels

    comparison — Related to comparison module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions