Skip to content

Commit 9c964ec

Browse files
authored
Add confidence scores when extended_summary=True in MeanAveragePrecision (#2212)
1 parent 9dd199a commit 9c964ec

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- Added support for Pytorch v2.1 ([#2142](https://github.com/Lightning-AI/torchmetrics/pull/2142))
2323

2424

25+
- Added confidence scores when `extended_summary=True` in `MeanAveragePrecision` ([#2212](https://github.com/Lightning-AI/torchmetrics/pull/2212))
26+
27+
2528
- Added support for logging `MultiTaskWrapper` directly with Lightning's `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213))
2629

2730

src/torchmetrics/detection/mean_ap.py

+4
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ class MeanAveragePrecision(Metric):
202202
- ``recall``: a tensor of shape ``(TxKxAxM)`` containing the recall values. Here ``T`` is the number of
203203
IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
204204
of max detections per image.
205+
- ``scores``: a tensor of shape ``(TxRxKxAxM)`` containing the confidence scores. Here ``T`` is the
206+
number of IoU thresholds, ``R`` is the number of recall thresholds, ``K`` is the number of classes,
207+
``A`` is the number of areas and ``M`` is the number of max detections per image.
205208
206209
average:
207210
Method for averaging scores over labels. Choose between ``"macro"`` and ``"micro"``.
@@ -531,6 +534,7 @@ def compute(self) -> dict:
531534
),
532535
f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]),
533536
f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]),
537+
f"{prefix}scores": torch.tensor(coco_eval.eval["scores"]),
534538
}
535539
result_dict.update(summary)
536540

tests/unittests/detection/test_map.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
741741
metric.update(preds, targets)
742742

743743
@pytest.mark.parametrize(
744-
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"),
744+
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"),
745745
[
746746
(
747747
[
@@ -758,6 +758,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
758758
[(0, 0)],
759759
(10, 101, 1, 4, 3),
760760
(10, 1, 4, 3),
761+
(10, 101, 1, 4, 3),
761762
),
762763
(
763764
_inputs["preds"],
@@ -766,11 +767,12 @@ def test_warning_on_many_detections(self, iou_type, backend):
766767
list(product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49])),
767768
(10, 101, 6, 4, 3),
768769
(10, 6, 4, 3),
770+
(10, 101, 6, 4, 3),
769771
),
770772
],
771773
)
772774
def test_for_extended_stats(
773-
self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, backend
775+
self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, scores_shape, backend
774776
):
775777
"""Test that extended stats are computed correctly."""
776778
metric = MeanAveragePrecision(extended_summary=True, backend=backend)
@@ -793,6 +795,10 @@ def test_for_extended_stats(
793795
assert isinstance(recall, Tensor)
794796
assert recall.shape == recall_shape
795797

798+
scores = result["scores"]
799+
assert isinstance(scores, Tensor)
800+
assert scores.shape == scores_shape
801+
796802
@pytest.mark.parametrize("class_metrics", [False, True])
797803
def test_average_argument(self, class_metrics, backend):
798804
"""Test that average argument works.

0 commit comments

Comments
 (0)