Skip to content

Commit 284205f

Browse files
SkafteNickiBorda
authored andcommitted
Fix classwise computation in IoU metric (#1924)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit eeb40e9)
1 parent 1525a28 commit 284205f

File tree

17 files changed

+624
-650
lines changed

17 files changed

+624
-650
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)
3333

3434

35+
- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))
36+
37+
3538
- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)
3639

40+
3741
## [1.1.0] - 2023-08-22
3842

3943
### Added

Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
export FREEZE_REQUIREMENTS=1
44
# assume you have installed need packages
55
export SPHINX_MOCK_REQUIREMENTS=1
6+
export SPHINX_FETCH_ASSETS=0
67

78
clean:
89
# clean all temp runs

src/torchmetrics/detection/ciou.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
3737
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
3838
detection boxes of the format specified in the constructor.
3939
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
40-
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
41-
for the boxes.
4240
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
4341
classes for the boxes.
4442
@@ -48,14 +46,14 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
4846
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground
4947
truth boxes of the format specified in the constructor.
5048
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
51-
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed ground truth
49+
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
5250
classes for the boxes.
5351
5452
As output of ``forward`` and ``compute`` the metric returns the following output:
5553
5654
- ``ciou_dict``: A dictionary containing the following key-values:
5755
58-
- ciou: (:class:`~torch.Tensor`)
56+
- ciou: (:class:`~torch.Tensor`) with overall ciou value over all classes and samples.
5957
- ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``
6058
6159
Args:
@@ -65,6 +63,9 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
6563
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
6664
class_metrics:
6765
Option to enable per-class metrics for IoU. Has a performance impact.
66+
respect_labels:
67+
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
68+
between all pairs of boxes.
6869
kwargs:
6970
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
7071
@@ -86,7 +87,7 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
8687
... ]
8788
>>> metric = CompleteIntersectionOverUnion()
8889
>>> metric(preds, target)
89-
{'ciou': tensor(-0.5694)}
90+
{'ciou': tensor(0.8611)}
9091
9192
Raises:
9293
ModuleNotFoundError:
@@ -105,14 +106,15 @@ def __init__(
105106
box_format: str = "xyxy",
106107
iou_threshold: Optional[float] = None,
107108
class_metrics: bool = False,
109+
respect_labels: bool = True,
108110
**kwargs: Any,
109111
) -> None:
110112
if not _TORCHVISION_GREATER_EQUAL_0_13:
111113
raise ModuleNotFoundError(
112114
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
113115
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
114116
)
115-
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
117+
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
116118

117119
@staticmethod
118120
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:

src/torchmetrics/detection/diou.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
3737
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
3838
detection boxes of the format specified in the constructor.
3939
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
40-
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
41-
for the boxes.
4240
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
4341
classes for the boxes.
4442
@@ -55,7 +53,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
5553
5654
- ``diou_dict``: A dictionary containing the following key-values:
5755
58-
- diou: (:class:`~torch.Tensor`)
56+
- diou: (:class:`~torch.Tensor`) with overall diou value over all classes and samples.
5957
- diou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``
6058
6159
Args:
@@ -65,6 +63,9 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
6563
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
6664
class_metrics:
6765
Option to enable per-class metrics for IoU. Has a performance impact.
66+
respect_labels:
67+
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
68+
between all pairs of boxes.
6869
kwargs:
6970
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
7071
@@ -86,7 +87,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
8687
... ]
8788
>>> metric = DistanceIntersectionOverUnion()
8889
>>> metric(preds, target)
89-
{'diou': tensor(-0.0694)}
90+
{'diou': tensor(0.8611)}
9091
9192
Raises:
9293
ModuleNotFoundError:
@@ -105,14 +106,15 @@ def __init__(
105106
box_format: str = "xyxy",
106107
iou_threshold: Optional[float] = None,
107108
class_metrics: bool = False,
109+
respect_labels: bool = True,
108110
**kwargs: Any,
109111
) -> None:
110112
if not _TORCHVISION_GREATER_EQUAL_0_13:
111113
raise ModuleNotFoundError(
112114
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
113115
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
114116
)
115-
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
117+
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
116118

117119
@staticmethod
118120
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:

src/torchmetrics/detection/giou.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
3737
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
3838
detection boxes of the format specified in the constructor.
3939
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
40-
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
41-
for the boxes.
4240
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
4341
classes for the boxes.
4442
@@ -55,7 +53,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
5553
5654
- ``giou_dict``: A dictionary containing the following key-values:
5755
58-
- giou: (:class:`~torch.Tensor`)
56+
- giou: (:class:`~torch.Tensor`) with overall giou value over all classes and samples.
5957
- giou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class metrics=True``
6058
6159
Args:
@@ -65,6 +63,9 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
6563
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
6664
class_metrics:
6765
Option to enable per-class metrics for IoU. Has a performance impact.
66+
respect_labels:
67+
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
68+
between all pairs of boxes.
6869
kwargs:
6970
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
7071
@@ -86,7 +87,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
8687
... ]
8788
>>> metric = GeneralizedIntersectionOverUnion()
8889
>>> metric(preds, target)
89-
{'giou': tensor(-0.0694)}
90+
{'giou': tensor(0.8613)}
9091
9192
Raises:
9293
ModuleNotFoundError:
@@ -105,9 +106,10 @@ def __init__(
105106
box_format: str = "xyxy",
106107
iou_threshold: Optional[float] = None,
107108
class_metrics: bool = False,
109+
respect_labels: bool = True,
108110
**kwargs: Any,
109111
) -> None:
110-
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
112+
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)
111113

112114
@staticmethod
113115
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:

src/torchmetrics/detection/helpers.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _input_validator(
2020
preds: Sequence[Dict[str, Tensor]],
2121
targets: Sequence[Dict[str, Tensor]],
2222
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox",
23+
ignore_score: bool = False,
2324
) -> None:
2425
"""Ensure the correct input format of `preds` and `targets`."""
2526
if isinstance(iou_type, str):
@@ -39,7 +40,7 @@ def _input_validator(
3940
f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}"
4041
)
4142

42-
for k in [*item_val_name, "scores", "labels"]:
43+
for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []):
4344
if any(k not in p for p in preds):
4445
raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")
4546

@@ -50,7 +51,7 @@ def _input_validator(
5051
for ivn in item_val_name:
5152
if any(type(pred[ivn]) is not Tensor for pred in preds):
5253
raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor")
53-
if any(type(pred["scores"]) is not Tensor for pred in preds):
54+
if not ignore_score and any(type(pred["scores"]) is not Tensor for pred in preds):
5455
raise ValueError("Expected all scores in `preds` to be of type Tensor")
5556
if any(type(pred["labels"]) is not Tensor for pred in preds):
5657
raise ValueError("Expected all labels in `preds` to be of type Tensor")
@@ -67,6 +68,8 @@ def _input_validator(
6768
f"Input '{ivn}' and labels of sample {i} in targets have a"
6869
f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})"
6970
)
71+
if ignore_score:
72+
return
7073
for i, item in enumerate(preds):
7174
for ivn in item_val_name:
7275
if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)):

0 commit comments

Comments
 (0)