Skip to content

Commit

Permalink
Merge branch 'master' into fix_lpips_functional_device_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Nov 24, 2023
2 parents cafecb2 + 9c964ec commit 3bb305c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for Pytorch v2.1 ([#2142](https://github.com/Lightning-AI/torchmetrics/pull/2142))


- Added confidence scores when `extended_summary=True` in `MeanAveragePrecision` ([#2212](https://github.com/Lightning-AI/torchmetrics/pull/2212))


- Added support for logging `MultiTaskWrapper` directly with lightnings `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213))


Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ class MeanAveragePrecision(Metric):
- ``recall``: a tensor of shape ``(TxKxAxM)`` containing the recall values. Here ``T`` is the number of
IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
of max detections per image.
- ``scores``: a tensor of shape ``(TxRxKxAxM)`` containing the confidence scores. Here ``T`` is the
number of IoU thresholds, ``R`` is the number of recall thresholds, ``K`` is the number of classes,
``A`` is the number of areas and ``M`` is the number of max detections per image.
average:
Method for averaging scores over labels. Choose between "``"macro"`` and ``"micro"``.
Expand Down Expand Up @@ -531,6 +534,7 @@ def compute(self) -> dict:
),
f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]),
f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]),
f"{prefix}scores": torch.tensor(coco_eval.eval["scores"]),
}
result_dict.update(summary)

Expand Down
10 changes: 8 additions & 2 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
metric.update(preds, targets)

@pytest.mark.parametrize(
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape"),
("preds", "target", "expected_iou_len", "iou_keys", "precision_shape", "recall_shape", "scores_shape"),
[
(
[
Expand All @@ -758,6 +758,7 @@ def test_warning_on_many_detections(self, iou_type, backend):
[(0, 0)],
(10, 101, 1, 4, 3),
(10, 1, 4, 3),
(10, 101, 1, 4, 3),
),
(
_inputs["preds"],
Expand All @@ -766,11 +767,12 @@ def test_warning_on_many_detections(self, iou_type, backend):
list(product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49])),
(10, 101, 6, 4, 3),
(10, 6, 4, 3),
(10, 101, 6, 4, 3),
),
],
)
def test_for_extended_stats(
self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, backend
self, preds, target, expected_iou_len, iou_keys, precision_shape, recall_shape, scores_shape, backend
):
"""Test that extended stats are computed correctly."""
metric = MeanAveragePrecision(extended_summary=True, backend=backend)
Expand All @@ -793,6 +795,10 @@ def test_for_extended_stats(
assert isinstance(recall, Tensor)
assert recall.shape == recall_shape

scores = result["scores"]
assert isinstance(scores, Tensor)
assert scores.shape == scores_shape

@pytest.mark.parametrize("class_metrics", [False, True])
def test_average_argument(self, class_metrics, backend):
"""Test that average argument works.
Expand Down

0 comments on commit 3bb305c

Please sign in to comment.