Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Aug 6, 2023
1 parent 5c0e083 commit f167c04
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions tests/unittests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,19 +628,19 @@ def test_error_on_wrong_input():
)


def _generate_random_segm_input(device):
def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True):
"""Generate random inputs for mAP when iou_type=segm."""
preds = []
targets = []
for _ in range(2):
for _ in range(batch_size):
result = {}
num_preds = torch.randint(0, 10, (1,)).item()
num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size
result["scores"] = torch.rand((num_preds,), device=device)
result["labels"] = torch.randint(0, 10, (num_preds,), device=device)
result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool()
preds.append(result)
gt = {}
num_gt = torch.randint(0, 10, (1,)).item()
num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size
gt["labels"] = torch.randint(0, 10, (num_gt,), device=device)
gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool()
targets.append(gt)
Expand Down Expand Up @@ -690,3 +690,23 @@ def test_for_box_format(box_format, iou_val_expected, map_val_expected):
result = metric.compute()
assert result["map"].item() == map_val_expected
assert round(float(metric.coco_eval.ious[(0, 0)]), 3) == iou_val_expected


@pytest.mark.parametrize("iou_type", ["bbox", "segm"])
def test_warning_on_many_detections(iou_type):
"""Test that a warning is raised when there are many detections."""
if iou_type == "bbox":
preds = [
{
"boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1),
"scores": torch.tensor([1.0]).repeat(101),
"labels": torch.tensor([0]).repeat(101),
}
]
targets = [{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}]
else:
preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False)

metric = MeanAveragePrecision(iou_type=iou_type)
with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"):
metric.update(preds, targets)

0 comments on commit f167c04

Please sign in to comment.