Skip to content

Commit 5890e42

Browse files
authored
Merge branch 'master' into lint/ruff
2 parents 5a1c635 + 0a82679 commit 5890e42

File tree

7 files changed

+95
-31
lines changed

7 files changed

+95
-31
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))
3737

3838

39+
- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))
40+
41+
3942
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
4043

4144

45+
- Fixed case where label prediction tensors in classification metrics were not validated correctly ([#2427](https://github.com/Lightning-AI/torchmetrics/pull/2427))
46+
47+
4248
- Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))
4349

4450
## [1.3.1] - 2024-02-12

src/torchmetrics/collections.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,12 +647,11 @@ def plot(
647647
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
648648
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
649649
)
650-
651650
val = val or self.compute()
652651
if together:
653652
return plot_single_or_multi_val(val, ax=ax)
654653
fig_axs = []
655-
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
654+
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
656655
if isinstance(val, dict):
657656
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
658657
elif isinstance(val, Sequence):

src/torchmetrics/functional/classification/confusion_matrix.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,21 +285,13 @@ def _multiclass_confusion_matrix_tensor_validation(
285285
" and `preds` should be (N, C, ...)."
286286
)
287287

288-
num_unique_values = len(torch.unique(target))
289-
check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
290-
if check:
291-
raise RuntimeError(
292-
"Detected more unique values in `target` than `num_classes`. Expected only "
293-
f"{num_classes if ignore_index is None else num_classes + 1} but found "
294-
f"{num_unique_values} in `target`."
295-
)
296-
297-
if not preds.is_floating_point():
298-
num_unique_values = len(torch.unique(preds))
299-
if num_unique_values > num_classes:
288+
check_value = num_classes if ignore_index is None else num_classes + 1
289+
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
290+
num_unique_values = len(torch.unique(t))
291+
if num_unique_values > check_value:
300292
raise RuntimeError(
301-
"Detected more unique values in `preds` than `num_classes`. Expected only "
302-
f"{num_classes} but found {num_unique_values} in `preds`."
293+
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
294+
f" {num_unique_values} in `target`."
303295
)
304296

305297

src/torchmetrics/functional/classification/stat_scores.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,21 +304,13 @@ def _multiclass_stat_scores_tensor_validation(
304304
" and `preds` should be (N, C, ...)."
305305
)
306306

307-
num_unique_values = len(torch.unique(target))
308-
check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
309-
if check:
310-
raise RuntimeError(
311-
"Detected more unique values in `target` than `num_classes`. Expected only"
312-
f" {num_classes if ignore_index is None else num_classes + 1} but found"
313-
f" {num_unique_values} in `target`."
314-
)
315-
316-
if not preds.is_floating_point():
317-
unique_values = torch.unique(preds)
318-
if len(unique_values) > num_classes:
307+
check_value = num_classes if ignore_index is None else num_classes + 1
308+
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
309+
num_unique_values = len(torch.unique(t))
310+
if num_unique_values > check_value:
319311
raise RuntimeError(
320-
"Detected more unique values in `preds` than `num_classes`. Expected only"
321-
f" {num_classes} but found {len(unique_values)} in `preds`."
312+
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
313+
f" {num_unique_values} in `target`."
322314
)
323315

324316

tests/unittests/classification/test_confusion_matrix.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,41 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype):
239239
)
240240

241241

242+
@pytest.mark.parametrize(
243+
("preds", "target", "ignore_index", "error_message"),
244+
[
245+
(
246+
torch.randint(NUM_CLASSES + 1, (100,)),
247+
torch.randint(NUM_CLASSES, (100,)),
248+
None,
249+
f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*",
250+
),
251+
(
252+
torch.randint(NUM_CLASSES, (100,)),
253+
torch.randint(NUM_CLASSES + 1, (100,)),
254+
None,
255+
f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*",
256+
),
257+
(
258+
torch.randint(NUM_CLASSES + 2, (100,)),
259+
torch.randint(NUM_CLASSES, (100,)),
260+
1,
261+
f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*",
262+
),
263+
(
264+
torch.randint(NUM_CLASSES, (100,)),
265+
torch.randint(NUM_CLASSES + 2, (100,)),
266+
1,
267+
f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*",
268+
),
269+
],
270+
)
271+
def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
272+
"""Test that an error is raised if the number of classes in preds or target is larger than expected."""
273+
with pytest.raises(RuntimeError, match=error_message):
274+
multiclass_confusion_matrix(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)
275+
276+
242277
def test_multiclass_overflow():
243278
"""Test that multiclass computations does not overflow even on byte inputs."""
244279
preds = torch.randint(20, (100,)).byte()

tests/unittests/classification/test_stat_scores.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,41 @@ def test_multiclass_stat_scores_dtype_gpu(self, inputs, dtype):
325325
)
326326

327327

328+
@pytest.mark.parametrize(
329+
("preds", "target", "ignore_index", "error_message"),
330+
[
331+
(
332+
torch.randint(NUM_CLASSES + 1, (100,)),
333+
torch.randint(NUM_CLASSES, (100,)),
334+
None,
335+
f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*",
336+
),
337+
(
338+
torch.randint(NUM_CLASSES, (100,)),
339+
torch.randint(NUM_CLASSES + 1, (100,)),
340+
None,
341+
f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*",
342+
),
343+
(
344+
torch.randint(NUM_CLASSES + 2, (100,)),
345+
torch.randint(NUM_CLASSES, (100,)),
346+
1,
347+
f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*",
348+
),
349+
(
350+
torch.randint(NUM_CLASSES, (100,)),
351+
torch.randint(NUM_CLASSES + 2, (100,)),
352+
1,
353+
f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*",
354+
),
355+
],
356+
)
357+
def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
358+
"""Test that an error is raised if the number of classes in preds or target is larger than expected."""
359+
with pytest.raises(RuntimeError, match=error_message):
360+
multiclass_stat_scores(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)
361+
362+
328363
_mc_k_target = torch.tensor([0, 1, 2])
329364
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
330365

tests/unittests/utilities/test_plot.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,17 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label
834834

835835
@pytest.mark.parametrize("together", [True, False])
836836
@pytest.mark.parametrize("num_vals", [1, 2])
837-
def test_plot_method_collection(together, num_vals):
837+
@pytest.mark.parametrize(
838+
("prefix", "postfix"), [(None, None), ("prefix", None), (None, "postfix"), ("prefix", "postfix")]
839+
)
840+
def test_plot_method_collection(together, num_vals, prefix, postfix):
838841
"""Test the plot method of metric collection."""
839842
m_collection = MetricCollection(
840843
BinaryAccuracy(),
841844
BinaryPrecision(),
842845
BinaryRecall(),
846+
prefix=prefix,
847+
postfix=postfix,
843848
)
844849
if num_vals == 1:
845850
m_collection.update(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)))

0 commit comments

Comments
 (0)