Skip to content

Commit 00121ff

Browse files
SkafteNickiBorda
authored andcommitted
Fix plotting of metric collection when prefix/postfix is set (Lightning-AI#2429)
* implementation * add tests * changelog (cherry picked from commit 0a82679)
1 parent d3e891e commit 00121ff

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3333
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))
3434

3535

36+
- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))
37+
38+
3639
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
3740

3841

src/torchmetrics/collections.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,12 +647,11 @@ def plot(
647647
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
648648
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
649649
)
650-
651650
val = val or self.compute()
652651
if together:
653652
return plot_single_or_multi_val(val, ax=ax)
654653
fig_axs = []
655-
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
654+
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
656655
if isinstance(val, dict):
657656
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
658657
elif isinstance(val, Sequence):

tests/unittests/utilities/test_plot.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,12 +855,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label
855855

856856
@pytest.mark.parametrize("together", [True, False])
857857
@pytest.mark.parametrize("num_vals", [1, 2])
858-
def test_plot_method_collection(together, num_vals):
858+
@pytest.mark.parametrize(
859+
("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")]
860+
)
861+
def test_plot_method_collection(together, num_vals, prefix, postfix):
859862
"""Test the plot method of metric collection."""
860863
m_collection = MetricCollection(
861864
BinaryAccuracy(),
862865
BinaryPrecision(),
863866
BinaryRecall(),
867+
prefix=prefix,
868+
postfix=postfix,
864869
)
865870
if num_vals == 1:
866871
m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))

0 commit comments

Comments
 (0)