Skip to content

Commit

Permalink
Fix axes in confusion matrix (#1976)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Aug 7, 2023
1 parent cd7ef55 commit 4259943
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed x/y labels when plotting confusion matrices ([#1976](https://github.com/Lightning-AI/torchmetrics/pull/1976))


- Fixed IOU compute in cuda ([#1982](https://github.com/Lightning-AI/torchmetrics/pull/1982))


Expand Down
34 changes: 34 additions & 0 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@
class BinaryConfusionMatrix(Metric):
r"""Compute the `confusion matrix`_ for binary tasks.
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
correspond to the true class labels and column indices correspond to the predicted class labels.
For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:
- :math:`C_{0, 0}`: True negatives
- :math:`C_{0, 1}`: False positives
- :math:`C_{1, 0}`: False negatives
- :math:`C_{1, 1}`: True positives
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
Expand Down Expand Up @@ -176,6 +187,17 @@ def plot(
class MulticlassConfusionMatrix(Metric):
r"""Compute the `confusion matrix`_ for multiclass tasks.
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
correspond to the true class labels and column indices correspond to the predicted class labels.
For multiclass tasks, the confusion matrix is a NxN matrix, where:
- :math:`C_{i, i}` represents the number of true positives for class :math:`i`
- :math:`\sum_{j=1, j\neq i}^N C_{i, j}` represents the number of false negatives for class :math:`i`
- :math:`\sum_{i=1, i\neq j}^N C_{i, j}` represents the number of false positives for class :math:`i`
- the sum of the remaining cells in the matrix represents the number of true negatives for class :math:`i`
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
Expand Down Expand Up @@ -305,6 +327,18 @@ def plot(
class MultilabelConfusionMatrix(Metric):
r"""Compute the `confusion matrix`_ for multilabel tasks.
The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations
known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix
correspond to the true class labels and column indices correspond to the predicted class labels.
For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion
for that label. The structure of each 2x2 matrix is as follows:
- :math:`C_{0, 0}`: True negatives
- :math:`C_{0, 1}`: False positives
- :math:`C_{1, 0}`: False negatives
- :math:`C_{1, 1}`: True positives
As input to 'update' the metric accepts the following input:
- ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def plot_confusion_matrix(
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
ax.set_xlabel("True class", fontsize=15)
ax.set_ylabel("Predicted class", fontsize=15)
ax.set_xlabel("Predicted class", fontsize=15)
ax.set_ylabel("True class", fontsize=15)
ax.set_xticks(list(range(n_classes)))
ax.set_yticks(list(range(n_classes)))
ax.set_xticklabels(labels, rotation=45, fontsize=10)
Expand Down

0 comments on commit 4259943

Please sign in to comment.