diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fd763a19af..d03efe15340 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961)) +- Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928)) + + - Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937)) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index c86787992f3..f3681545b96 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -11,21 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Sequence +from typing import Dict, Literal, Sequence, Tuple, Union from torch import Tensor def _input_validator( - preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: str = "bbox" + preds: Sequence[Dict[str, Tensor]], + targets: Sequence[Dict[str, Tensor]], + iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox", ) -> None: """Ensure the correct input format of `preds` and `targets`.""" - if iou_type == "bbox": - item_val_name = "boxes" - elif iou_type == "segm": - item_val_name = "masks" - else: + if isinstance(iou_type, str): + iou_type = (iou_type,) + + name_map = {"bbox": "boxes", "segm": "masks"} + if any(tp not in name_map for tp in iou_type): raise Exception(f"IOU type {iou_type} is not supported") + item_val_name = [name_map[tp] for tp in iou_type] if not isinstance(preds, Sequence): raise ValueError(f"Expected argument `preds` to be of type Sequence, but got {preds}") @@ -36,38 +39,42 @@ def _input_validator( f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}" ) - for k in [item_val_name, "scores", "labels"]: + for k in [*item_val_name, "scores", "labels"]: if any(k not in p for p in preds): raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key") - for k in [item_val_name, "labels"]: + for k in [*item_val_name, "labels"]: if any(k not in p for p in targets): raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key") - if any(type(pred[item_val_name]) is not Tensor for pred in preds): - raise ValueError(f"Expected all {item_val_name} in `preds` to be of type Tensor") + for ivn in item_val_name: + if any(type(pred[ivn]) is not Tensor for pred in preds): + raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor") if any(type(pred["scores"]) is not Tensor for pred in preds): raise ValueError("Expected all scores in `preds` to be of type Tensor") if any(type(pred["labels"]) is not Tensor for pred in preds): raise ValueError("Expected all labels in `preds` to be of type Tensor") - if any(type(target[item_val_name]) is not Tensor for target in targets): - raise ValueError(f"Expected all {item_val_name} in `target` to be of type Tensor") + for ivn in item_val_name: + if any(type(target[ivn]) is not Tensor for target in targets): + raise ValueError(f"Expected all {ivn} in `target` to be of type Tensor") if any(type(target["labels"]) is not Tensor for 
target in targets): raise ValueError("Expected all labels in `target` to be of type Tensor") for i, item in enumerate(targets): - if item[item_val_name].size(0) != item["labels"].size(0): - raise ValueError( - f"Input {item_val_name} and labels of sample {i} in targets have a" - f" different length (expected {item[item_val_name].size(0)} labels, got {item['labels'].size(0)})" - ) + for ivn in item_val_name: + if item[ivn].size(0) != item["labels"].size(0): + raise ValueError( + f"Input '{ivn}' and labels of sample {i} in targets have a" + f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})" + ) for i, item in enumerate(preds): - if not (item[item_val_name].size(0) == item["labels"].size(0) == item["scores"].size(0)): - raise ValueError( - f"Input {item_val_name}, labels and scores of sample {i} in predictions have a" - f" different length (expected {item[item_val_name].size(0)} labels and scores," - f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" - ) + for ivn in item_val_name: + if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)): + raise ValueError( + f"Input '{ivn}', labels and scores of sample {i} in predictions have a" + f" different length (expected {item[ivn].size(0)} labels and scores," + f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})" + ) def _fix_empty_tensors(boxes: Tensor) -> Tensor: @@ -75,3 +82,14 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: if boxes.numel() == 0 and boxes.ndim == 1: return boxes.unsqueeze(0) return boxes + + +def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]: + allowed_iou_types = ("segm", "bbox") + if isinstance(iou_type, str): + iou_type = (iou_type,) + if any(tp not in allowed_iou_types for tp in iou_type): + raise ValueError( + f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}" + ) + return iou_type diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 9190c4135af..4f545e40f0e 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -22,7 +22,7 @@ from torch import distributed as dist from typing_extensions import Literal -from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator +from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator, _validate_iou_type_arg from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import apply_to_collection @@ -85,7 +85,8 @@ class MeanAveragePrecision(Metric): - boxes: (:class:`~torch.FloatTensor`) of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection boxes of the format specified in the constructor. - By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. + By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates, but can be changed + using the ``box_format`` parameter. Only required when `iou_type="bbox"`. - scores: :class:`~torch.FloatTensor` of shape ``(num_boxes)`` containing detection scores for the boxes. - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed detection classes for the boxes. @@ -96,7 +97,7 @@ class MeanAveragePrecision(Metric): (each dictionary corresponds to a single image). 
Parameters that should be provided per dict: - boxes: :class:`~torch.FloatTensor` of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth - boxes of the format specified in the constructor. + boxes of the format specified in the constructor. only required when `iou_type="bbox"`. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - labels: :class:`~torch.IntTensor` of shape ``(num_boxes)`` containing 0-indexed ground truth classes for the boxes. @@ -138,7 +139,6 @@ class MeanAveragePrecision(Metric): .. note:: ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well. - The default properties are also accessible via fields and will raise an ``AttributeError`` if not available. .. note:: This metric utilizes the official `pycocotools` implementation as its backend. This means that the metric @@ -157,8 +157,8 @@ class MeanAveragePrecision(Metric): width and height. iou_type: - Type of input (either masks or bounding-boxes) used for computing IOU. - Supported IOU types are ``["bbox", "segm"]``. If using ``"segm"``, masks should be provided in input. + Type of input (either masks or bounding-boxes) used for computing IOU. Supported IOU types are + ``"bbox"`` or ``"segm"`` or both as a tuple. iou_thresholds: IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]`` with step ``0.05``. Else provide a list of floats. @@ -206,7 +206,11 @@ class MeanAveragePrecision(Metric): ValueError: If ``class_metrics`` is not a boolean - Example: + Example:: + + Basic example for when `iou_type="bbox"`. In this case the ``boxes`` key is required in the input dictionaries, + in addition to the ``scores`` and ``labels`` keys. + >>> from torch import tensor >>> from torchmetrics.detection import MeanAveragePrecision >>> preds = [ @@ -222,7 +226,7 @@ class MeanAveragePrecision(Metric): ... labels=tensor([0]), ... ) ... ] - >>> metric = MeanAveragePrecision() + >>> metric = MeanAveragePrecision(iou_type="bbox") >>> metric.update(preds, target) >>> from pprint import pprint >>> pprint(metric.compute()) @@ -242,6 +246,60 @@ class MeanAveragePrecision(Metric): 'mar_medium': tensor(-1.), 'mar_small': tensor(-1.)} + Example:: + + Basic example for when `iou_type="segm"`. In this case the ``masks`` key is required in the input dictionaries, + in addition to the ``scores`` and ``labels`` keys. + + >>> from torch import tensor + >>> from torchmetrics.detection import MeanAveragePrecision + >>> mask_pred = [ + ... [0, 0, 0, 0, 0], + ... [0, 0, 1, 1, 0], + ... [0, 0, 1, 1, 0], + ... [0, 0, 0, 0, 0], + ... [0, 0, 0, 0, 0], + ... ] + >>> mask_tgt = [ + ... [0, 0, 0, 0, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, 1, 1, 0], + ... [0, 0, 1, 0, 0], + ... [0, 0, 0, 0, 0], + ... ] + >>> preds = [ + ... dict( + ... masks=tensor([mask_pred], dtype=torch.bool), + ... scores=tensor([0.536]), + ... labels=tensor([0]), + ... ) + ... ] + >>> target = [ + ... dict( + ... masks=tensor([mask_tgt], dtype=torch.bool), + ... labels=tensor([0]), + ... ) + ... 
] + >>> metric = MeanAveragePrecision(iou_type="segm") + >>> metric.update(preds, target) + >>> from pprint import pprint + >>> pprint(metric.compute()) + {'classes': tensor(0, dtype=torch.int32), + 'map': tensor(0.2000), + 'map_50': tensor(1.), + 'map_75': tensor(0.), + 'map_large': tensor(-1.), + 'map_medium': tensor(-1.), + 'map_per_class': tensor(-1.), + 'map_small': tensor(0.2000), + 'mar_1': tensor(0.2000), + 'mar_10': tensor(0.2000), + 'mar_100': tensor(0.2000), + 'mar_100_per_class': tensor(-1.), + 'mar_large': tensor(-1.), + 'mar_medium': tensor(-1.), + 'mar_small': tensor(0.2000)} + """ is_differentiable: bool = False higher_is_better: Optional[bool] = True @@ -249,10 +307,12 @@ class MeanAveragePrecision(Metric): plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 - detections: List[Tensor] + detection_box: List[Tensor] + detection_mask: List[Tensor] detection_scores: List[Tensor] detection_labels: List[Tensor] - groundtruths: List[Tensor] + groundtruth_box: List[Tensor] + groundtruth_mask: List[Tensor] groundtruth_labels: List[Tensor] groundtruth_crowds: List[Tensor] groundtruth_area: List[Tensor] @@ -262,7 +322,7 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Literal["bbox", "segm"] = "bbox", + iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox", iou_thresholds: Optional[List[float]] = None, rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, @@ -288,10 +348,7 @@ def __init__( raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}") self.box_format = box_format - allowed_iou_types = ("segm", "bbox") - if iou_type not in allowed_iou_types: - raise ValueError(f"Expected argument `iou_type` to be one of {allowed_iou_types} but got {iou_type}") - self.iou_type = iou_type + self.iou_type = _validate_iou_type_arg(iou_type) if iou_thresholds is not None and not isinstance(iou_thresholds, list): raise ValueError( @@ -321,10 +378,12 @@ def __init__( raise ValueError("Expected argument `extended_summary` to be a boolean") self.extended_summary = extended_summary - self.add_state("detections", default=[], dist_reduce_fx=None) + self.add_state("detection_box", default=[], dist_reduce_fx=None) + self.add_state("detection_mask", default=[], dist_reduce_fx=None) self.add_state("detection_scores", default=[], dist_reduce_fx=None) self.add_state("detection_labels", default=[], dist_reduce_fx=None) - self.add_state("groundtruths", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_box", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_mask", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None) self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) @@ -354,15 +413,20 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - detections = self._get_safe_item_values(item, warn=self.warn_on_many_detections) - - self.detections.append(detections) + bbox_detection, mask_detection = self._get_safe_item_values(item, warn=self.warn_on_many_detections) + if bbox_detection is not None: + self.detection_box.append(bbox_detection) + if mask_detection is not None: + self.detection_mask.append(mask_detection) self.detection_labels.append(item["labels"]) 
self.detection_scores.append(item["scores"]) for item in target: - groundtruths = self._get_safe_item_values(item) - self.groundtruths.append(groundtruths) + bbox_groundtruth, mask_groundtruth = self._get_safe_item_values(item) + if bbox_groundtruth is not None: + self.groundtruth_box.append(bbox_groundtruth) + if mask_groundtruth is not None: + self.groundtruth_mask.append(mask_groundtruth) self.groundtruth_labels.append(item["labels"]) self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) self.groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"]))) @@ -372,77 +436,107 @@ def compute(self) -> dict: coco_target, coco_preds = COCO(), COCO() coco_target.dataset = self._get_coco_format( - self.groundtruths, self.groundtruth_labels, crowds=self.groundtruth_crowds, area=self.groundtruth_area + labels=self.groundtruth_labels, + boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None, + masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None, + crowds=self.groundtruth_crowds, + area=self.groundtruth_area, + ) + coco_preds.dataset = self._get_coco_format( + labels=self.detection_labels, + boxes=self.detection_box if len(self.detection_box) > 0 else None, + masks=self.detection_mask if len(self.detection_mask) > 0 else None, + scores=self.detection_scores, ) - coco_preds.dataset = self._get_coco_format(self.detections, self.detection_labels, scores=self.detection_scores) + result_dict = {} with contextlib.redirect_stdout(io.StringIO()): coco_target.createIndex() coco_preds.createIndex() - coco_eval = COCOeval(coco_target, coco_preds, iouType=self.iou_type) - coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) - coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) - coco_eval.params.maxDets = self.max_detection_thresholds - - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - stats = coco_eval.stats - - summary = {} - if self.extended_summary: - summary = { - "ious": apply_to_collection(coco_eval.ious, np.ndarray, lambda x: torch.tensor(x, dtype=torch.float32)), - "precision": torch.tensor(coco_eval.eval["precision"]), # precision has shape (TxRxKxAxM) - "recall": torch.tensor(coco_eval.eval["recall"]), # recall has shape (TxKxAxM) - } - - # if class mode is enabled, evaluate metrics per class - if self.class_metrics: - map_per_class_list = [] - mar_100_per_class_list = [] - for class_id in self._get_classes(): - coco_eval.params.catIds = [class_id] - with contextlib.redirect_stdout(io.StringIO()): - coco_eval.evaluate() - coco_eval.accumulate() - coco_eval.summarize() - class_stats = coco_eval.stats - - map_per_class_list.append(torch.tensor([class_stats[0]])) - mar_100_per_class_list.append(torch.tensor([class_stats[8]])) - - map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) - mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) - else: - map_per_class_values = torch.tensor([-1], dtype=torch.float32) - mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32) + for i_type in self.iou_type: + prefix = "" if len(self.iou_type) == 1 else f"{i_type}_" + if len(self.iou_type) > 1: + # the area calculation is different for bbox and segm and therefore to get the small, medium and + # large values correct we need to dynamically change the area attribute of the annotations + for anno in coco_preds.dataset["annotations"]: + anno["area"] = anno[f"area_{i_type}"] + + coco_eval = 
COCOeval(coco_target, coco_preds, iouType=i_type) + coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64) + coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64) + coco_eval.params.maxDets = self.max_detection_thresholds + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + stats = coco_eval.stats + result_dict.update(self._coco_stats_to_tensor_dict(stats, prefix=prefix)) + + summary = {} + if self.extended_summary: + summary = { + f"{prefix}ious": apply_to_collection( + coco_eval.ious, np.ndarray, lambda x: torch.tensor(x, dtype=torch.float32) + ), + f"{prefix}precision": torch.tensor(coco_eval.eval["precision"]), + f"{prefix}recall": torch.tensor(coco_eval.eval["recall"]), + } + result_dict.update(summary) + + # if class mode is enabled, evaluate metrics per class + if self.class_metrics: + map_per_class_list = [] + mar_100_per_class_list = [] + for class_id in self._get_classes(): + coco_eval.params.catIds = [class_id] + with contextlib.redirect_stdout(io.StringIO()): + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + class_stats = coco_eval.stats + + map_per_class_list.append(torch.tensor([class_stats[0]])) + mar_100_per_class_list.append(torch.tensor([class_stats[8]])) + + map_per_class_values = torch.tensor(map_per_class_list, dtype=torch.float32) + mar_100_per_class_values = torch.tensor(mar_100_per_class_list, dtype=torch.float32) + else: + map_per_class_values = torch.tensor([-1], dtype=torch.float32) + mar_100_per_class_values = torch.tensor([-1], dtype=torch.float32) + prefix = "" if len(self.iou_type) == 1 else f"{i_type}_" + result_dict.update( + { + f"{prefix}map_per_class": map_per_class_values, + f"{prefix}mar_100_per_class": mar_100_per_class_values, + }, + ) + result_dict.update({"classes": torch.tensor(self._get_classes(), dtype=torch.int32)}) + + return result_dict + @staticmethod + def _coco_stats_to_tensor_dict(stats: List[float], prefix: str) -> Dict[str, Tensor]: return { - "map": torch.tensor([stats[0]], dtype=torch.float32), - "map_50": torch.tensor([stats[1]], dtype=torch.float32), - "map_75": torch.tensor([stats[2]], dtype=torch.float32), - "map_small": torch.tensor([stats[3]], dtype=torch.float32), - "map_medium": torch.tensor([stats[4]], dtype=torch.float32), - "map_large": torch.tensor([stats[5]], dtype=torch.float32), - "mar_1": torch.tensor([stats[6]], dtype=torch.float32), - "mar_10": torch.tensor([stats[7]], dtype=torch.float32), - "mar_100": torch.tensor([stats[8]], dtype=torch.float32), - "mar_small": torch.tensor([stats[9]], dtype=torch.float32), - "mar_medium": torch.tensor([stats[10]], dtype=torch.float32), - "mar_large": torch.tensor([stats[11]], dtype=torch.float32), - "map_per_class": map_per_class_values, - "mar_100_per_class": mar_100_per_class_values, - "classes": torch.tensor(self._get_classes(), dtype=torch.int32), - **summary, + f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), + f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), + f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), + f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), + f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), + f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), + f"{prefix}mar_1": torch.tensor([stats[6]], dtype=torch.float32), + f"{prefix}mar_10": torch.tensor([stats[7]], dtype=torch.float32), + f"{prefix}mar_100": torch.tensor([stats[8]], dtype=torch.float32), + 
f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), + f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), + f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), } @staticmethod def coco_to_tm( coco_preds: str, coco_target: str, - iou_type: Literal["bbox", "segm"] = "bbox", + iou_type: Union[Literal["bbox", "segm"], List[str]] = "bbox", ) -> Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]: """Utility function for converting .json coco format files to the input format of this metric. @@ -470,6 +564,8 @@ def coco_to_tm( ... ) # doctest: +SKIP """ + iou_type = _validate_iou_type_arg(iou_type) + with contextlib.redirect_stdout(io.StringIO()): gt = COCO(coco_target) dt = gt.loadRes(coco_preds) @@ -481,14 +577,18 @@ def coco_to_tm( for t in gt_dataset: if t["image_id"] not in target: target[t["image_id"]] = { - "boxes" if iou_type == "bbox" else "masks": [], "labels": [], "iscrowd": [], "area": [], } - if iou_type == "bbox": + if "bbox" in iou_type: + target[t["image_id"]]["boxes"] = [] + if "segm" in iou_type: + target[t["image_id"]]["masks"] = [] + + if "bbox" in iou_type: target[t["image_id"]]["boxes"].append(t["bbox"]) - else: + if "segm" in iou_type: target[t["image_id"]]["masks"].append(gt.annToMask(t)) target[t["image_id"]]["labels"].append(t["category_id"]) target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) @@ -497,39 +597,47 @@ def coco_to_tm( preds = {} for p in dt_dataset: if p["image_id"] not in preds: - preds[p["image_id"]] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} - if iou_type == "bbox": + preds[p["image_id"]] = {"scores": [], "labels": []} + if "bbox" in iou_type: + preds[p["image_id"]]["boxes"] = [] + if "segm" in iou_type: + preds[p["image_id"]]["masks"] = [] + if "bbox" in iou_type: preds[p["image_id"]]["boxes"].append(p["bbox"]) - else: + if "segm" in iou_type: preds[p["image_id"]]["masks"].append(gt.annToMask(p)) preds[p["image_id"]]["scores"].append(p["score"]) preds[p["image_id"]]["labels"].append(p["category_id"]) for k in target: # add empty predictions for images without predictions if k not in preds: - preds[k] = {"boxes" if iou_type == "bbox" else "masks": [], "scores": [], "labels": []} + preds[k] = {"scores": [], "labels": []} + if "bbox" in iou_type: + preds[k]["boxes"] = [] + if "segm" in iou_type: + preds[k]["masks"] = [] batched_preds, batched_target = [], [] for key in target: - name = "boxes" if iou_type == "bbox" else "masks" - batched_preds.append( - { - name: torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) - if iou_type == "bbox" - else torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8), - "scores": torch.tensor(preds[key]["scores"], dtype=torch.float32), - "labels": torch.tensor(preds[key]["labels"], dtype=torch.int32), - } - ) - batched_target.append( - { - name: torch.tensor(target[key]["boxes"], dtype=torch.float32) - if iou_type == "bbox" - else torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8), - "labels": torch.tensor(target[key]["labels"], dtype=torch.int32), - "iscrowd": torch.tensor(target[key]["iscrowd"], dtype=torch.int32), - "area": torch.tensor(target[key]["area"], dtype=torch.float32), - } - ) + bp = { + "scores": torch.tensor(preds[key]["scores"], dtype=torch.float32), + "labels": torch.tensor(preds[key]["labels"], dtype=torch.int32), + } + if "bbox" in iou_type: + bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) + if "segm" in iou_type: + bp["masks"] = 
torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8) + batched_preds.append(bp) + + bt = { + "labels": torch.tensor(target[key]["labels"], dtype=torch.int32), + "iscrowd": torch.tensor(target[key]["iscrowd"], dtype=torch.int32), + "area": torch.tensor(target[key]["area"], dtype=torch.float32), + } + if "bbox" in iou_type: + bt["boxes"] = torch.tensor(target[key]["boxes"], dtype=torch.float32) + if "segm" in iou_type: + bt["masks"] = torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8) + batched_target.append(bt) return batched_preds, batched_target @@ -564,8 +672,16 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: >>> metric.tm_to_coco("tm_map_input") # doctest: +SKIP """ - target_dataset = self._get_coco_format(self.groundtruths, self.groundtruth_labels) - preds_dataset = self._get_coco_format(self.detections, self.detection_labels, self.detection_scores) + target_dataset = self._get_coco_format( + labels=self.groundtruth_labels, + boxes=self.groundtruth_box, + masks=self.groundtruth_mask, + crowds=self.groundtruth_crowds, + area=self.groundtruth_area, + ) + preds_dataset = self._get_coco_format( + labels=self.detection_labels, boxes=self.detection_box, masks=self.detection_mask + ) preds_json = json.dumps(preds_dataset["annotations"], indent=4) target_json = json.dumps(target_dataset, indent=4) @@ -576,7 +692,9 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: with open(f"{name}_target.json", "w") as f: f.write(target_json) - def _get_safe_item_values(self, item: Dict[str, Any], warn: bool = False) -> Union[Tensor, Tuple]: + def _get_safe_item_values( + self, item: Dict[str, Any], warn: bool = False + ) -> Tuple[Optional[Tensor], Optional[Tuple]]: """Convert and return the boxes or masks from the item depending on the iou_type. 
Args: @@ -587,22 +705,23 @@ def _get_safe_item_values(self, item: Dict[str, Any], warn: bool = False) -> Uni boxes or masks depending on the iou_type """ - if self.iou_type == "bbox": + output = [None, None] + if "bbox" in self.iou_type: boxes = _fix_empty_tensors(item["boxes"]) if boxes.numel() > 0: boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh") - if warn and len(boxes) > self.max_detection_thresholds[-1]: - _warning_on_too_many_detections(self.max_detection_thresholds[-1]) - return boxes - if self.iou_type == "segm": + output[0] = boxes + if "segm" in self.iou_type: masks = [] for i in item["masks"].cpu().numpy(): rle = mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) - if warn and len(masks) > self.max_detection_thresholds[-1]: - _warning_on_too_many_detections(self.max_detection_thresholds[-1]) - return tuple(masks) - raise Exception(f"IOU type {self.iou_type} is not supported") + output[1] = tuple(masks) + if (output[0] is not None and len(output[0]) > self.max_detection_thresholds[-1]) or ( + output[1] is not None and len(output[1]) > self.max_detection_thresholds[-1] + ): + _warning_on_too_many_detections(self.max_detection_thresholds[-1]) + return output def _get_classes(self) -> List: """Return a list of unique classes found in ground truth and detection data.""" @@ -612,8 +731,9 @@ def _get_classes(self) -> List: def _get_coco_format( self, - boxes: List[torch.Tensor], labels: List[torch.Tensor], + boxes: Optional[List[torch.Tensor]] = None, + masks: Optional[List[torch.Tensor]] = None, scores: Optional[List[torch.Tensor]] = None, crowds: Optional[List[torch.Tensor]] = None, area: Optional[List[torch.Tensor]] = None, @@ -628,20 +748,28 @@ def _get_coco_format( annotations = [] annotation_id = 1 # has to start with 1, otherwise COCOEval results are wrong - for image_id, (image_boxes, image_labels) in enumerate(zip(boxes, labels)): - if self.iou_type == "segm" and len(image_boxes) == 0: - continue - - if self.iou_type == "bbox": + for image_id, image_labels in enumerate(labels): + if boxes is not None: + image_boxes = boxes[image_id] image_boxes = image_boxes.cpu().tolist() + if masks is not None: + image_masks = masks[image_id] + if len(image_masks) == 0 and boxes is None: + continue image_labels = image_labels.cpu().tolist() images.append({"id": image_id}) - if self.iou_type == "segm": - images[-1]["height"], images[-1]["width"] = image_boxes[0][0][0], image_boxes[0][0][1] + if "segm" in self.iou_type and len(image_masks) > 0: + images[-1]["height"], images[-1]["width"] = image_masks[0][0][0], image_masks[0][0][1] - for k, (image_box, image_label) in enumerate(zip(image_boxes, image_labels)): - if self.iou_type == "bbox" and len(image_box) != 4: + for k, image_label in enumerate(image_labels): + if boxes is not None: + image_box = image_boxes[k] + if masks is not None and len(image_masks) > 0: + image_mask = image_masks[k] + image_mask = {"size": image_mask[0], "counts": image_mask[1]} + + if "bbox" in self.iou_type and len(image_box) != 4: raise ValueError( f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})" ) @@ -652,21 +780,31 @@ def _get_coco_format( f" (expected value of type integer, got type {type(image_label)})" ) - stat = image_box if self.iou_type == "bbox" else {"size": image_box[0], "counts": image_box[1]} - + area_stat_box = None + area_stat_mask = None if area is not None and area[image_id][k].cpu().tolist() > 0: area_stat = area[image_id][k].cpu().tolist() 
else: - area_stat = image_box[2] * image_box[3] if self.iou_type == "bbox" else mask_utils.area(stat) + area_stat = mask_utils.area(image_mask) if "segm" in self.iou_type else image_box[2] * image_box[3] + if len(self.iou_type) > 1: + area_stat_box = image_box[2] * image_box[3] + area_stat_mask = mask_utils.area(image_mask) annotation = { "id": annotation_id, "image_id": image_id, - "bbox" if self.iou_type == "bbox" else "segmentation": stat, "area": area_stat, "category_id": image_label, "iscrowd": crowds[image_id][k].cpu().tolist() if crowds is not None else 0, } + if area_stat_box is not None: + annotation["area_bbox"] = area_stat_box + annotation["area_segm"] = area_stat_mask + + if boxes is not None: + annotation["bbox"] = image_box + if masks is not None: + annotation["segmentation"] = image_mask if scores is not None: score = scores[image_id][k].cpu().tolist() @@ -752,7 +890,7 @@ def _apply(self, fn: Callable) -> torch.nn.Module: # type: ignore[override] no longer a tensor but a tuple. """ - return super()._apply(fn, exclude_state=("detections", "groundtruths") if self.iou_type == "segm" else "") + return super()._apply(fn, exclude_state=("detection_mask", "groundtruth_mask")) def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None: """Custom sync function. @@ -763,9 +901,9 @@ def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Opt """ super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group) - if self.iou_type == "segm": - self.detections = self._gather_tuple_list(self.detections, process_group) - self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group) + if "segm" in self.iou_type: + self.detection_mask = self._gather_tuple_list(self.detection_mask, process_group) + self.groundtruth_mask = self._gather_tuple_list(self.groundtruth_mask, process_group) @staticmethod def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]: diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 61da40ecf82..de3b5805254 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -13,6 +13,7 @@ # limitations under the License. 
import contextlib import io +import json from collections import namedtuple from copy import deepcopy from functools import partial @@ -55,7 +56,7 @@ def _generate_coco_inputs(iou_type): _coco_segm_input = _generate_coco_inputs("segm") -def _compare_again_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True): +def _compare_against_coco_fn(preds, target, iou_type, iou_thresholds=None, rec_thresholds=None, class_metrics=True): """Taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb.""" with contextlib.redirect_stdout(io.StringIO()): gt = COCO(_DETECTION_VAL) @@ -130,7 +131,7 @@ def test_map(self, iou_type, iou_thresholds, rec_thresholds, ddp): target=target, metric_class=MeanAveragePrecision, reference_metric=partial( - _compare_again_coco_fn, + _compare_against_coco_fn, iou_type=iou_type, iou_thresholds=iou_thresholds, rec_thresholds=rec_thresholds, @@ -159,13 +160,49 @@ def test_map_classwise(self, iou_type, ddp): preds=preds, target=target, metric_class=MeanAveragePrecision, - reference_metric=partial(_compare_again_coco_fn, iou_type=iou_type, class_metrics=True), + reference_metric=partial(_compare_against_coco_fn, iou_type=iou_type, class_metrics=True), metric_args={"box_format": "xywh", "iou_type": iou_type, "class_metrics": True}, check_batch=False, atol=1e-1, ) +def test_compare_both_same_time(tmpdir): + """Test that the class support evaluating both bbox and segm at the same time.""" + with open(_DETECTION_BBOX) as f: + boxes = json.load(f) + with open(_DETECTION_SEGM) as f: + segmentations = json.load(f) + combined = [{**box, **seg} for box, seg in zip(boxes, segmentations)] + with open(f"{tmpdir}/combined.json", "w") as f: + json.dump(combined, f) + batched_preds, batched_target = MeanAveragePrecision.coco_to_tm( + f"{tmpdir}/combined.json", _DETECTION_VAL, iou_type=["bbox", "segm"] + ) + batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)] + batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(10)] + + metric = MeanAveragePrecision(iou_type=["bbox", "segm"], box_format="xywh") + for bp, bt in zip(batched_preds, batched_target): + metric.update(bp, bt) + res = metric.compute() + + res1 = _compare_against_coco_fn([], [], iou_type="bbox", class_metrics=False) + res2 = _compare_against_coco_fn([], [], iou_type="segm", class_metrics=False) + + for k, v in res1.items(): + if k == "classes": + continue + assert f"bbox_{k}" in res + assert torch.allclose(res[f"bbox_{k}"], v, atol=1e-2) + + for k, v in res2.items(): + if k == "classes": + continue + assert f"segm_{k}" in res + assert torch.allclose(res[f"segm_{k}"], v, atol=1e-2) + + Input = namedtuple("Input", ["preds", "target"])
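
Usage sketch (not part of the patch): the snippet below illustrates how the combined `iou_type` support added above is intended to be used, assuming this patch is applied. The boxes, masks, and scores are illustrative values borrowed from the docstring examples; per the changes to `compute()`, when more than one IoU type is given every COCO statistic is returned under a `bbox_`/`segm_` prefix while `classes` stays unprefixed.

# Usage sketch for the combined iou_type support introduced in this patch.
# Values are illustrative (adapted from the docstring examples above); this is
# not part of the diff itself.
import torch
from torchmetrics.detection import MeanAveragePrecision

# A single predicted/true instance with both a box and a binary mask.
mask_pred = torch.zeros(1, 5, 5, dtype=torch.bool)
mask_pred[0, 1:3, 2:4] = True
mask_tgt = torch.zeros(1, 5, 5, dtype=torch.bool)
mask_tgt[0, 1:4, 2] = True
mask_tgt[0, 2, 3] = True

preds = [
    {
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "masks": mask_pred,
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }
]
target = [
    {
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "masks": mask_tgt,
        "labels": torch.tensor([0]),
    }
]

# Passing a tuple (or list) evaluates both IoU types in one metric instance;
# both the "boxes" and "masks" keys are then required by the input validator.
metric = MeanAveragePrecision(iou_type=("bbox", "segm"))
metric.update(preds, target)
result = metric.compute()

# With more than one IoU type, each COCO statistic appears twice, prefixed with
# "bbox_" / "segm_" (e.g. "bbox_map", "segm_map_50"); "classes" is shared.
print(sorted(result))

The same multi-type form is accepted by `coco_to_tm` (as exercised in `test_compare_both_same_time`), which then populates both the `boxes` and `masks` keys in the returned prediction and target dictionaries.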