Skip to content

Commit 0a82679

Browse files
authored
Fix plotting of metric collection when prefix/postfix is set (Lightning-AI#2429)
* implementation * add tests * changelog
1 parent 5980744 commit 0a82679

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))
3737

3838

39+
- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))
40+
41+
3942
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
4043

4144

src/torchmetrics/collections.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,12 +647,11 @@ def plot(
647647
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
648648
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
649649
)
650-
651650
val = val or self.compute()
652651
if together:
653652
return plot_single_or_multi_val(val, ax=ax)
654653
fig_axs = []
655-
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
654+
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
656655
if isinstance(val, dict):
657656
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
658657
elif isinstance(val, Sequence):

tests/unittests/utilities/test_plot.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label
834834

835835
@pytest.mark.parametrize("together", [True, False])
836836
@pytest.mark.parametrize("num_vals", [1, 2])
837-
def test_plot_method_collection(together, num_vals):
837+
@pytest.mark.parametrize(
838+
("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")]
839+
)
840+
def test_plot_method_collection(together, num_vals, prefix, postfix):
838841
"""Test the plot method of metric collection."""
839842
m_collection = MetricCollection(
840843
BinaryAccuracy(),
841844
BinaryPrecision(),
842845
BinaryRecall(),
846+
prefix=prefix,
847+
postfix=postfix,
843848
)
844849
if num_vals == 1:
845850
m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))

0 commit comments

Comments
 (0)