Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warning on many detections in MeanAveragePrecision #1978

Merged
merged 4 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))


### Changed

-
Expand Down
22 changes: 20 additions & 2 deletions src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import (
_MATPLOTLIB_AVAILABLE,
_PYCOCOTOOLS_AVAILABLE,
Expand Down Expand Up @@ -239,6 +240,8 @@ class MeanAveragePrecision(Metric):
groundtruth_crowds: List[Tensor]
groundtruth_area: List[Tensor]

warn_on_many_detections: bool = True

def __init__(
self,
box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
Expand Down Expand Up @@ -329,7 +332,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
_input_validator(preds, target, iou_type=self.iou_type)

for item in preds:
detections = self._get_safe_item_values(item)
detections = self._get_safe_item_values(item, warn=self.warn_on_many_detections)

self.detections.append(detections)
self.detection_labels.append(item["labels"])
Expand Down Expand Up @@ -542,11 +545,12 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None:
with open(f"{name}_target.json", "w") as f:
f.write(target_json)

def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:
def _get_safe_item_values(self, item: Dict[str, Any], warn: bool = False) -> Union[Tensor, Tuple]:
"""Convert and return the boxes or masks from the item depending on the iou_type.

Args:
item: input dictionary containing the boxes or masks
warn: whether to warn if the number of boxes or masks exceeds the max_detection_thresholds

Returns:
boxes or masks depending on the iou_type
Expand All @@ -556,12 +560,16 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:
boxes = _fix_empty_tensors(item["boxes"])
if boxes.numel() > 0:
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh")
if warn and len(boxes) > self.max_detection_thresholds[-1]:
_warning_on_too_many_detections(self.max_detection_thresholds[-1])
return boxes
if self.iou_type == "segm":
masks = []
for i in item["masks"].cpu().numpy():
rle = mask_utils.encode(np.asfortranarray(i))
masks.append((tuple(rle["size"]), rle["counts"]))
if warn and len(masks) > self.max_detection_thresholds[-1]:
_warning_on_too_many_detections(self.max_detection_thresholds[-1])
return tuple(masks)
raise Exception(f"IOU type {self.iou_type} is not supported")

Expand Down Expand Up @@ -747,3 +755,13 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any]
dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)]


def _warning_on_too_many_detections(limit: int) -> None:
rank_zero_warn(
f"Encountered more than {limit} detections in a single image. This means that certain detections with the"
" lowest scores will be ignored, that may have an undesirable impact on performance. Please consider adjusting"
" the `max_detection_threshold` to suit your use case. To disable this warning, set attribute class"
" `warn_on_many_detections=False`, after initializing the metric.",
UserWarning,
)
28 changes: 24 additions & 4 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,19 +628,19 @@ def test_error_on_wrong_input():
)


def _generate_random_segm_input(device):
def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True):
"""Generate random inputs for mAP when iou_type=segm."""
preds = []
targets = []
for _ in range(2):
for _ in range(batch_size):
result = {}
num_preds = torch.randint(0, 10, (1,)).item()
num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size
result["scores"] = torch.rand((num_preds,), device=device)
result["labels"] = torch.randint(0, 10, (num_preds,), device=device)
result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool()
preds.append(result)
gt = {}
num_gt = torch.randint(0, 10, (1,)).item()
num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size
gt["labels"] = torch.randint(0, 10, (num_gt,), device=device)
gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool()
targets.append(gt)
Expand Down Expand Up @@ -690,3 +690,23 @@ def test_for_box_format(box_format, iou_val_expected, map_val_expected):
result = metric.compute()
assert result["map"].item() == map_val_expected
assert round(float(metric.coco_eval.ious[(0, 0)]), 3) == iou_val_expected


@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
def test_warning_on_many_detections(iou_type):
"""Test that a warning is raised when there are many detections."""
if iou_type == "bbox":
preds = [
{
"boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1),
"scores": torch.tensor([1.0]).repeat(101),
"labels": torch.tensor([0]).repeat(101),
}
]
targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]
else:
preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False)

metric = MeanAveragePrecision(iou_type=iou_type)
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)
Loading