Skip to content

Commit 2ae0c8b

Browse files
JonasVerbickasBorda
authored andcommitted
Improve confusion matrix plotting (Lightning-AI#2358)
Round floats to avoid floating point errors leading to UI overflow. Remove overlapping text in multilabel plots by reducing redundant `Predicted class` and `True class` labels. Use `constrained_layout` to prevent some text from being cut off. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com> (cherry picked from commit 71089f0)
1 parent cba48a2 commit 2ae0c8b

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4545
- Fixed initialize aggregation metrics with default floating type ([#2366](https://github.com/Lightning-AI/torchmetrics/pull/2366))
4646

4747

48+
- Fixed plotting of confusion matrices ([#2358](https://github.com/Lightning-AI/torchmetrics/pull/2358))
49+
50+
4851
## [1.3.0] - 2024-01-10
4952

5053
### Added

src/torchmetrics/utilities/plot.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,17 @@ def plot_confusion_matrix(
242242
fig_label = None
243243
labels = labels or np.arange(n_classes).tolist()
244244

245-
fig, axs = plt.subplots(nrows=rows, ncols=cols) if ax is None else (ax.get_figure(), ax)
245+
fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
246246
axs = trim_axs(axs, nb)
247247
for i in range(nb):
248248
ax = axs[i] if rows != 1 and cols != 1 else axs
249249
if fig_label is not None:
250250
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
251251
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
252-
ax.set_xlabel("Predicted class", fontsize=15)
253-
ax.set_ylabel("True class", fontsize=15)
252+
if i // cols == rows - 1: # bottom row only
253+
ax.set_xlabel("Predicted class", fontsize=15)
254+
if i % cols == 0: # leftmost column only
255+
ax.set_ylabel("True class", fontsize=15)
254256
ax.set_xticks(list(range(n_classes)))
255257
ax.set_yticks(list(range(n_classes)))
256258
ax.set_xticklabels(labels, rotation=45, fontsize=10)
@@ -259,7 +261,7 @@ def plot_confusion_matrix(
259261
if add_text:
260262
for ii, jj in product(range(n_classes), range(n_classes)):
261263
val = confmat[i, ii, jj] if confmat.ndim == 3 else confmat[ii, jj]
262-
ax.text(jj, ii, str(val.item()), ha="center", va="center", fontsize=15)
264+
ax.text(jj, ii, str(round(val.item(), 2)), ha="center", va="center", fontsize=15)
263265

264266
return fig, axs
265267

0 commit comments

Comments
 (0)