Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MeanAveragePrecision box and segm at same time #1928

Merged
merged 26 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0c8e1ec
add code
SkafteNicki Jul 19, 2023
cf97263
add test
SkafteNicki Jul 19, 2023
d3c6fa1
docs changes
SkafteNicki Jul 19, 2023
d926d3b
changelog
SkafteNicki Jul 19, 2023
e98551d
add example text
SkafteNicki Jul 19, 2023
c0ead51
Update src/torchmetrics/detection/helpers.py
SkafteNicki Jul 25, 2023
1e5e8d1
merge master
SkafteNicki Jul 25, 2023
f984250
Merge branch 'master' into feature/box_segm_at_same_time
SkafteNicki Jul 25, 2023
2a47a66
fix syncronization error for segm
SkafteNicki Jul 25, 2023
1fb7279
Merge branch 'master' into feature/box_segm_at_same_time
SkafteNicki Jul 28, 2023
3cbab0c
Apply suggestions from code review
SkafteNicki Aug 1, 2023
3f05d04
Merge branch 'master' into feature/box_segm_at_same_time
SkafteNicki Aug 2, 2023
0e70ffb
Apply suggestions from code review
SkafteNicki Aug 2, 2023
da318cf
smaller corrections
SkafteNicki Aug 2, 2023
33526cc
Merge branch 'master' into feature/box_segm_at_same_time
SkafteNicki Aug 5, 2023
0bdb146
Merge branch 'master' into feature/box_segm_at_same_time
Borda Aug 7, 2023
ce9006b
merge master
SkafteNicki Aug 7, 2023
b0b05b6
fix mistake
SkafteNicki Aug 7, 2023
dfa7e2e
Merge branch 'master' into feature/box_segm_at_same_time
Borda Aug 7, 2023
21fbc6c
merge
Borda Aug 7, 2023
2b887ad
Merge branch 'master' into feature/box_segm_at_same_time
Borda Aug 7, 2023
a6daa55
Merge branch 'master' into feature/box_segm_at_same_time
Borda Aug 8, 2023
57778d9
Merge branch 'master' into feature/box_segm_at_same_time
Borda Aug 8, 2023
00c39b0
merge master
SkafteNicki Aug 8, 2023
455720c
Merge branch 'master' into feature/box_segm_at_same_time
mergify[bot] Aug 8, 2023
d5b9080
Merge branch 'master' into feature/box_segm_at_same_time
mergify[bot] Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


- Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928))


- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))



### Changed

-
Expand Down
66 changes: 42 additions & 24 deletions src/torchmetrics/detection/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Sequence
from typing import Dict, Literal, Sequence, Tuple, Union

from torch import Tensor


def _input_validator(
preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: str = "bbox"
preds: Sequence[Dict[str, Tensor]],
targets: Sequence[Dict[str, Tensor]],
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox",
) -> None:
"""Ensure the correct input format of `preds` and `targets`."""
if iou_type == "bbox":
item_val_name = "boxes"
elif iou_type == "segm":
item_val_name = "masks"
else:
if isinstance(iou_type, str):
iou_type = (iou_type,)

name_map = {"bbox": "boxes", "segm": "masks"}
if any(tp not in name_map for tp in iou_type):
raise Exception(f"IOU type {iou_type} is not supported")
item_val_name = [name_map[tp] for tp in iou_type]

if not isinstance(preds, Sequence):
raise ValueError(f"Expected argument `preds` to be of type Sequence, but got {preds}")
Expand All @@ -36,42 +39,57 @@ def _input_validator(
f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}"
)

for k in [item_val_name, "scores", "labels"]:
for k in [*item_val_name, "scores", "labels"]:
if any(k not in p for p in preds):
raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")

for k in [item_val_name, "labels"]:
for k in [*item_val_name, "labels"]:
if any(k not in p for p in targets):
raise ValueError(f"Expected all dicts in `target` to contain the `{k}` key")

if any(type(pred[item_val_name]) is not Tensor for pred in preds):
raise ValueError(f"Expected all {item_val_name} in `preds` to be of type Tensor")
for ivn in item_val_name:
if any(type(pred[ivn]) is not Tensor for pred in preds):
raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor")
if any(type(pred["scores"]) is not Tensor for pred in preds):
raise ValueError("Expected all scores in `preds` to be of type Tensor")
if any(type(pred["labels"]) is not Tensor for pred in preds):
raise ValueError("Expected all labels in `preds` to be of type Tensor")
if any(type(target[item_val_name]) is not Tensor for target in targets):
raise ValueError(f"Expected all {item_val_name} in `target` to be of type Tensor")
for ivn in item_val_name:
if any(type(target[ivn]) is not Tensor for target in targets):
raise ValueError(f"Expected all {ivn} in `target` to be of type Tensor")
if any(type(target["labels"]) is not Tensor for target in targets):
raise ValueError("Expected all labels in `target` to be of type Tensor")

for i, item in enumerate(targets):
if item[item_val_name].size(0) != item["labels"].size(0):
raise ValueError(
f"Input {item_val_name} and labels of sample {i} in targets have a"
f" different length (expected {item[item_val_name].size(0)} labels, got {item['labels'].size(0)})"
)
for ivn in item_val_name:
if item[ivn].size(0) != item["labels"].size(0):
raise ValueError(
f"Input '{ivn}' and labels of sample {i} in targets have a"
f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})"
)
for i, item in enumerate(preds):
if not (item[item_val_name].size(0) == item["labels"].size(0) == item["scores"].size(0)):
raise ValueError(
f"Input {item_val_name}, labels and scores of sample {i} in predictions have a"
f" different length (expected {item[item_val_name].size(0)} labels and scores,"
f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})"
)
for ivn in item_val_name:
if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)):
raise ValueError(
f"Input '{ivn}', labels and scores of sample {i} in predictions have a"
f" different length (expected {item[ivn].size(0)} labels and scores,"
f" got {item['labels'].size(0)} labels and {item['scores'].size(0)})"
)


def _fix_empty_tensors(boxes: Tensor) -> Tensor:
"""Empty tensors can cause problems in DDP mode, this methods corrects them."""
if boxes.numel() == 0 and boxes.ndim == 1:
return boxes.unsqueeze(0)
return boxes


def _validate_iou_type_arg(iou_type: Union[Literal["bbox", "segm"], Tuple[str]] = "bbox") -> Tuple[str]:
allowed_iou_types = ("segm", "bbox")
if isinstance(iou_type, str):
iou_type = (iou_type,)
if any(tp not in allowed_iou_types for tp in iou_type):
raise ValueError(
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a list of, but got {iou_type}"
)
return iou_type
Loading
Loading