diff --git a/CHANGELOG.md b/CHANGELOG.md index 381b6754c10..d20e3aa1dd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed IOU compute in cuda ([#1982](https://github.com/Lightning-AI/torchmetrics/pull/1982)) ## [1.0.2] - 2023-08-02 diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 43abc597d22..1f37ef01a27 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -38,7 +38,7 @@ def _iou_update( def _iou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: if labels_eq: return iou.diag().mean() - return iou.mean() if iou.numel() > 0 else torch.tensor(0.0) + return iou.mean() if iou.numel() > 0 else torch.tensor(0.0).to(iou.device) def intersection_over_union(