
Commit 9253717

Fix naming of statistics in MeanAveragePrecision with custom max det thresholds (Lightning-AI#2367)
* fix src code + fix docs
* fix tests
1 parent dce2368 commit 9253717

3 files changed: +33 -42 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed cached network in `FeatureShare` not being moved to the correct device ([#2348](https://github.com/Lightning-AI/torchmetrics/pull/2348))


+- Fix naming of statistics in `MeanAveragePrecision` with custom max det thresholds ([#2367](https://github.com/Lightning-AI/torchmetrics/pull/2367))
+
+
 - Fixed custom aggregation in retrieval metrics ([#2364](https://github.com/Lightning-AI/torchmetrics/pull/2364))


src/torchmetrics/detection/mean_ap.py

Lines changed: 22 additions & 20 deletions
@@ -128,9 +128,12 @@ class MeanAveragePrecision(Metric):
     - map_small: (:class:`~torch.Tensor`), mean average precision for small objects
     - map_medium: (:class:`~torch.Tensor`), mean average precision for medium objects
     - map_large: (:class:`~torch.Tensor`), mean average precision for large objects
-    - mar_1: (:class:`~torch.Tensor`), mean average recall for 1 detection per image
-    - mar_10: (:class:`~torch.Tensor`), mean average recall for 10 detections per image
-    - mar_100: (:class:`~torch.Tensor`), mean average recall for 100 detections per image
+    - mar_{mdt[0]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[0]` (default 1)
+      detection per image
+    - mar_{mdt[1]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[1]` (default 10)
+      detections per image
+    - mar_{mdt[2]}: (:class:`~torch.Tensor`), mean average recall for `max_detection_thresholds[2]` (default 100)
+      detections per image
     - mar_small: (:class:`~torch.Tensor`), mean average recall for small objects
     - mar_medium: (:class:`~torch.Tensor`), mean average recall for medium objects
     - mar_large: (:class:`~torch.Tensor`), mean average recall for large objects
@@ -140,8 +143,8 @@ class MeanAveragePrecision(Metric):
       IoU=0.75
     - map_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average precision per
       observed class
-    - mar_100_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average recall for 100
-      detections per image per observed class
+    - mar_{mdt[2]}_per_class: (:class:`~torch.Tensor`) (-1 if class metrics are disabled), mean average recall for
+      `max_detection_thresholds[2]` (default 100) detections per image per observed class
     - classes (:class:`~torch.Tensor`), list of all observed classes

     For an example on how to use this metric check the `torchmetrics mAP example`_.
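To make the new naming concrete, a minimal usage sketch (assuming one of the COCO backends, `pycocotools` or `faster_coco_eval`, is installed; the [1, 10, 1000] thresholds are purely illustrative):

from torchmetrics.detection import MeanAveragePrecision

# Any sorted list of three ints is accepted; [1, 10, 1000] is an illustrative choice.
metric = MeanAveragePrecision(max_detection_thresholds=[1, 10, 1000])
print(metric.max_detection_thresholds)  # [1, 10, 1000]
# After update()/compute(), the recall entries are keyed "mar_1", "mar_10" and
# "mar_1000" instead of always "mar_1"/"mar_10"/"mar_100".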
@@ -184,8 +187,7 @@ class MeanAveragePrecision(Metric):
             with step ``0.01``. Else provide a list of floats.
         max_detection_thresholds:
             Thresholds on max detections per image. If set to `None` will use thresholds ``[1, 10, 100]``.
-            Else, please provide a list of ints. If the `pycocotools` backend is used then the list needs to have
-            length 3. If this is a problem, shift to `faster_coco_eval` which supports more detection thresholds.
+            Else, please provide a list of ints of length 3, which is the only supported length by both backends.
         class_metrics:
             Option to enable per-class metrics for mAP and mAR_100. Has a performance impact that scales linearly with
             the number of classes in the dataset.
@@ -410,10 +412,10 @@ def __init__(
                 f"Expected argument `max_detection_thresholds` to either be `None` or a list of ints"
                 f" but got {max_detection_thresholds}"
             )
-        if max_detection_thresholds is not None and backend == "pycocotools" and len(max_detection_thresholds) != 3:
+        if max_detection_thresholds is not None and len(max_detection_thresholds) != 3:
             raise ValueError(
-                "When using `pycocotools` backend the number of max detection thresholds should be 3 else"
-                f" it will not work correctly with the backend. Got value {len(max_detection_thresholds)}."
+                "When providing a list of max detection thresholds it should have length 3."
+                f" Got value {len(max_detection_thresholds)}"
             )
         max_det_thr, _ = torch.sort(torch.tensor(max_detection_thresholds or [1, 10, 100], dtype=torch.int))
         self.max_detection_thresholds = max_det_thr.tolist()
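As a usage sketch of the tightened check (again assuming a COCO backend is installed; the four-element list is illustrative):

from torchmetrics.detection import MeanAveragePrecision

# Lists whose length is not exactly 3 are now rejected for every backend.
try:
    MeanAveragePrecision(max_detection_thresholds=[1, 10, 50, 100])
except ValueError as err:
    print(err)  # e.g. "When providing a list of max detection thresholds it should have length 3. Got value 4"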
@@ -556,7 +558,7 @@ def compute(self) -> dict:
                 coco_eval.params.maxDets = self.max_detection_thresholds

                 map_per_class_list = []
-                mar_100_per_class_list = []
+                mar_per_class_list = []
                 for class_id in self._get_classes():
                     coco_eval.params.catIds = [class_id]
                     with contextlib.redirect_stdout(io.StringIO()):
@@ -566,18 +568,18 @@ def compute(self) -> dict:
                         class_stats = coco_eval.stats

                     map_per_class_list.append(torch.tensor([class_stats[0]]))
-                    mar_100_per_class_list.append(torch.tensor([class_stats[8]]))
+                    mar_per_class_list.append(torch.tensor([class_stats[8]]))

                 map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32)
-                mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32)
+                mar_per_class_values = torch.tensor(mar_per_class_list, dtype=torch.float32)
             else:
                 map_per_class_values = torch.tensor([-1], dtype=torch.float32)
-                mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32)
+                mar_per_class_values = torch.tensor([-1], dtype=torch.float32)
             prefix = "" if len(self.iou_type) == 1 else f"{i_type}_"
             result_dict.update(
                 {
                     f"{prefix}map_per_class": map_per_class_values,
-                    f"{prefix}mar_100_per_class": mar_100_per_class_values,
+                    f"{prefix}mar_{self.max_detection_thresholds[-1]}_per_class": mar_per_class_values,
                 },
             )
         result_dict.update({"classes": torch.tensor(self._get_classes(), dtype=torch.int32)})
@@ -616,19 +618,19 @@ def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[object

         return coco_preds, coco_target

-    @staticmethod
-    def _coco_stats_to_tensor_dict(stats: List[float], prefix: str) -> Dict[str, Tensor]:
+    def _coco_stats_to_tensor_dict(self, stats: List[float], prefix: str) -> Dict[str, Tensor]:
         """Converts the output of COCOeval.stats to a dict of tensors."""
+        mdt = self.max_detection_thresholds
         return {
             f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
             f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),
             f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32),
             f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32),
             f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32),
             f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32),
-            f"{prefix}mar_1": torch.tensor([stats[6]], dtype=torch.float32),
-            f"{prefix}mar_10": torch.tensor([stats[7]], dtype=torch.float32),
-            f"{prefix}mar_100": torch.tensor([stats[8]], dtype=torch.float32),
+            f"{prefix}mar_{mdt[0]}": torch.tensor([stats[6]], dtype=torch.float32),
+            f"{prefix}mar_{mdt[1]}": torch.tensor([stats[7]], dtype=torch.float32),
+            f"{prefix}mar_{mdt[2]}": torch.tensor([stats[8]], dtype=torch.float32),
             f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32),
             f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32),
             f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32),

tests/unittests/detection/test_map.py

Lines changed: 8 additions & 22 deletions
@@ -861,6 +861,10 @@ def test_many_detection_thresholds(self, backend):
         else:
             assert round(res["map"].item(), 5) == 0.6

+        assert "mar_1" in res
+        assert "mar_10" in res
+        assert "mar_1000" in res
+
     @pytest.mark.parametrize("max_detection_thresholds", [[1, 10], [1, 10, 50, 100]])
     def test_with_more_and_less_detection_thresholds(self, max_detection_thresholds, backend):
         """Test how metric is working when list of max detection thresholds is not 3.
@@ -869,25 +873,7 @@ def test_with_more_and_less_detection_thresholds(self, max_detection_thresholds,
         https://github.com/ppwwyyxx/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L461

         """
-        preds = [
-            {
-                "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
-                "scores": torch.tensor([0.536]),
-                "labels": torch.tensor([0]),
-            }
-        ]
-        target = [
-            {
-                "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
-                "labels": torch.tensor([0]),
-            }
-        ]
-
-        if backend == "pycocotools":
-            with pytest.raises(
-                ValueError, match="When using `pycocotools` backend the number of max detection thresholds should.*"
-            ):
-                metric = MeanAveragePrecision(max_detection_thresholds=max_detection_thresholds, backend=backend)
-        else:
-            metric = MeanAveragePrecision(max_detection_thresholds=max_detection_thresholds, backend=backend)
-            metric(preds, target)
+        with pytest.raises(
+            ValueError, match="When providing a list of max detection thresholds it should have length 3.*"
+        ):
+            MeanAveragePrecision(max_detection_thresholds=max_detection_thresholds, backend=backend)
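For completeness, the kind of end-to-end check the updated test performs, written as a standalone sketch (boxes and scores are taken from the original test data; the function name and the [1, 10, 1000] thresholds are illustrative, and a COCO backend must be installed):

import torch
from torchmetrics.detection import MeanAveragePrecision


def check_custom_threshold_names() -> None:
    """The recall keys follow the custom thresholds, e.g. mar_1000 instead of mar_100."""
    metric = MeanAveragePrecision(max_detection_thresholds=[1, 10, 1000])
    preds = [{
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }]
    target = [{
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }]
    metric.update(preds, target)
    res = metric.compute()
    assert "mar_1" in res and "mar_10" in res and "mar_1000" in res


check_custom_threshold_names()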
