Warning on many detections in MeanAveragePrecision (#1978)
(cherry picked from commit a4f67f3)
SkafteNicki authored and Borda committed Aug 8, 2023
1 parent a483720 commit 80fa056
Showing 3 changed files with 45 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))


### Changed
22 changes: 20 additions & 2 deletions src/torchmetrics/detection/mean_ap.py
@@ -24,6 +24,7 @@

from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import (
_MATPLOTLIB_AVAILABLE,
_PYCOCOTOOLS_AVAILABLE,
@@ -238,6 +239,8 @@ class MeanAveragePrecision(Metric):
groundtruth_crowds: List[Tensor]
groundtruth_area: List[Tensor]

warn_on_many_detections: bool = True

def __init__(
self,
box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
@@ -327,7 +330,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
_input_validator(preds, target, iou_type=self.iou_type)

for item in preds:
detections = self._get_safe_item_values(item)
detections = self._get_safe_item_values(item, warn=self.warn_on_many_detections)

self.detections.append(detections)
self.detection_labels.append(item["labels"])
@@ -540,11 +543,12 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None:
with open(f"{name}_target.json", "w") as f:
f.write(target_json)

def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:
def _get_safe_item_values(self, item: Dict[str, Any], warn: bool = False) -> Union[Tensor, Tuple]:
"""Convert and return the boxes or masks from the item depending on the iou_type.
Args:
item: input dictionary containing the boxes or masks
warn: whether to warn if the number of boxes or masks exceeds the max_detection_thresholds
Returns:
boxes or masks depending on the iou_type
@@ -554,12 +558,16 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:
boxes = _fix_empty_tensors(item["boxes"])
if boxes.numel() > 0:
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh")
if warn and len(boxes) > self.max_detection_thresholds[-1]:
_warning_on_too_many_detections(self.max_detection_thresholds[-1])
return boxes
if self.iou_type == "segm":
masks = []
for i in item["masks"].cpu().numpy():
rle = mask_utils.encode(np.asfortranarray(i))
masks.append((tuple(rle["size"]), rle["counts"]))
if warn and len(masks) > self.max_detection_thresholds[-1]:
_warning_on_too_many_detections(self.max_detection_thresholds[-1])
return tuple(masks)
raise Exception(f"IOU type {self.iou_type} is not supported")

@@ -741,3 +749,13 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any]
dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]


def _warning_on_too_many_detections(limit: int) -> None:
rank_zero_warn(
f"Encountered more than {limit} detections in a single image. This means that certain detections with the"
" lowest scores will be ignored, that may have an undesirable impact on performance. Please consider adjusting"
" the `max_detection_threshold` to suit your use case. To disable this warning, set attribute class"
" `warn_on_many_detections=False`, after initializing the metric.",
UserWarning,
)
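
For illustration (not part of the diff above; box values and thresholds are made up), a minimal sketch of the two ways a user can avoid the new warning: raise the last entry of the existing `max_detection_thresholds` constructor argument, or flip the new `warn_on_many_detections` class attribute after initialization.

import torch
from torchmetrics.detection import MeanAveragePrecision

# Option 1: raise the last detection threshold so more detections per image are kept (illustrative values).
metric = MeanAveragePrecision(iou_type="bbox", max_detection_thresholds=[1, 10, 500])

# Option 2: keep the default thresholds but silence the warning via the new class attribute.
quiet_metric = MeanAveragePrecision(iou_type="bbox")
quiet_metric.warn_on_many_detections = False

# 200 identical detections for a single image would otherwise exceed the default limit of 100.
preds = [{
    "boxes": torch.tensor([[0.0, 0.0, 20.0, 20.0]]).repeat(200, 1),
    "scores": torch.linspace(0, 1, 200),
    "labels": torch.zeros(200, dtype=torch.long),
}]
target = [{"boxes": torch.tensor([[0.0, 0.0, 20.0, 20.0]]), "labels": torch.tensor([0])}]
quiet_metric.update(preds, target)  # no UserWarning is emitted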
28 changes: 24 additions & 4 deletions tests/unittests/detection/test_map.py
@@ -622,19 +622,19 @@ def test_error_on_wrong_input():
)


def _generate_random_segm_input(device):
def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True):
"""Generate random inputs for mAP when iou_type=segm."""
preds = []
targets = []
for _ in range(2):
for _ in range(batch_size):
result = {}
num_preds = torch.randint(0, 10, (1,)).item()
num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size
result["scores"] = torch.rand((num_preds,), device=device)
result["labels"] = torch.randint(0, 10, (num_preds,), device=device)
result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool()
preds.append(result)
gt = {}
num_gt = torch.randint(0, 10, (1,)).item()
num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size
gt["labels"] = torch.randint(0, 10, (num_gt,), device=device)
gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool()
targets.append(gt)
@@ -683,3 +683,23 @@ def test_for_box_format(box_format, iou_val_expected, map_val_expected):
result = metric.compute()
assert result["map"].item() == map_val_expected
assert round(float(metric.coco_eval.ious[(0, 0)]), 3) == iou_val_expected


@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
def test_warning_on_many_detections(iou_type):
"""Test that a warning is raised when there are many detections."""
if iou_type == "bbox":
preds = [
{
"boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1),
"scores": torch.tensor([1.0]).repeat(101),
"labels": torch.tensor([0]).repeat(101),
}
]
targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]
else:
preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False)

metric = MeanAveragePrecision(iou_type=iou_type)
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)
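
Outside the test suite, the same behaviour can be checked with the standard library's `warnings` module; a rough sketch (inputs mirror the bbox case in the test above):

import warnings
import torch
from torchmetrics.detection import MeanAveragePrecision

preds = [{
    "boxes": torch.tensor([[0.5, 0.5, 1.0, 1.0]]).repeat(101, 1),
    "scores": torch.ones(101),
    "labels": torch.zeros(101, dtype=torch.long),
}]
targets = [{"boxes": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "labels": torch.tensor([0])}]

metric = MeanAveragePrecision(iou_type="bbox")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    metric.update(preds, targets)

# One UserWarning mentioning the 100-detection limit is expected here.
assert any("more than 100 detections" in str(w.message) for w in caught)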
