
MeanIoU calculation error in 1.8.0.dev0 (built from source) #3047

Open
@insomniaaac


🐛 Bug

MeanIoU returns incorrect per-class IoU values. Details follow.

To Reproduce

Steps to reproduce the behavior:

import torch

gt = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
        [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
        [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
        [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    dtype=torch.int64,
)  # 9x10 index tensor (class indices 0-5)

pred = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
        [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
        [0, 5, 5, 5, 0, 0, 3, 3, 3, 0],
        [5, 5, 5, 5, 5, 0, 0, 0, 0, 0],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    dtype=torch.int64,
)  # 9x10 index tensor (class indices 0-5)

Note that the expected per-class IoU values are 0.6667 for class 1 (intersection 6, union 9), 1.0 for classes 2, 3 and 4, and 0.3846 for class 5 (intersection 5, union 13).
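The expected values in the note above can be cross-checked without torchmetrics. The following is a minimal pure-Python sketch that counts, for each class c, the pixels where both maps equal c (intersection) and where either map equals c (union), using the same two grids copied as nested lists. The helper name `per_class_iou` and the NaN convention for a class absent from both maps are assumptions made for illustration, not torchmetrics behavior.

```python
# Ground truth and prediction copied from the reproduction above (9 rows x 10 columns).
GT = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
PRED = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 5, 5, 5, 0, 0, 3, 3, 3, 0],
    [5, 5, 5, 5, 5, 0, 0, 0, 0, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]


def per_class_iou(pred, gt, num_classes, include_background=False):
    """Hypothetical helper: IoU(c) = |pred == c and gt == c| / |pred == c or gt == c|."""
    # Flatten both grids into (pred, gt) pixel pairs.
    pairs = [(p, g) for prow, grow in zip(pred, gt) for p, g in zip(prow, grow)]
    first = 0 if include_background else 1  # skip class 0 unless background is included
    ious = []
    for c in range(first, num_classes):
        inter = sum(1 for p, g in pairs if p == c and g == c)
        union = sum(1 for p, g in pairs if p == c or g == c)
        # Assumed convention: a class absent from both maps gets NaN.
        ious.append(inter / union if union else float("nan"))
    return ious


ious = per_class_iou(PRED, GT, num_classes=6)
print([round(v, 4) for v in ious])  # [0.6667, 1.0, 1.0, 1.0, 0.3846]
print(round(sum(ious) / len(ious), 4))  # 0.8103
```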

I am using torchmetrics 1.8.0.dev0 (built from the latest source):

from torchmetrics.segmentation import MeanIoU as torchmetricsMeanIoU
torchmetrics_miou = torchmetricsMeanIoU(
    num_classes=6, include_background=False, per_class=True
)
torchmetrics_miou.update(pred, gt)
print("torchmetrics_miou", torchmetrics_miou.compute())

The result is:

torchmetrics_miou tensor([0.7483, 0.7483, 0.7483, 0.7483, 0.7483])

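This is not correct: every class gets the same value, 0.7483, which matches none of the expected per-class values and is not even their mean (about 0.8103).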
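The repeated value 0.7483 can be reproduced exactly under one hypothesis, which has not been confirmed against the torchmetrics source: that `MeanIoU` (assuming its `input_format` argument defaults to `"one-hot"`) read each 9x10 index tensor as one-hot input of shape (N=9, C=10), dropped channel 0 (column 0) because `include_background=False`, and, with no spatial dimensions left to reduce, summed the element-wise bitwise AND and the raw values over the whole tensor into a single score. If that scalar were then broadcast across the five per-class slots, every class would report the same number. The sketch below redoes only that arithmetic in plain Python on the same grids; it does not call torchmetrics.

```python
# Same grids as the reproduction (9 rows x 10 columns), copied as nested lists.
GT = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
PRED = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 1, 1, 1, 0, 0, 3, 3, 3, 0],
    [0, 5, 5, 5, 0, 0, 3, 3, 3, 0],
    [5, 5, 5, 5, 5, 0, 0, 0, 0, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 2, 2, 2, 0, 0, 4, 4, 4, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]

# Hypothesis (unconfirmed): rows are "samples", columns are one-hot "class channels".
# include_background=False drops channel 0, i.e. column 0 of every row.
pred_kept = [row[1:] for row in PRED]
gt_kept = [row[1:] for row in GT]

# Bitwise AND of the raw integer values (as `&` would do on int tensors), summed over everything.
inter = sum(p & g for prow, grow in zip(pred_kept, gt_kept) for p, g in zip(prow, grow))
pred_sum = sum(sum(row) for row in pred_kept)
gt_sum = sum(sum(row) for row in gt_kept)
union = pred_sum + gt_sum - inter

print(inter, union, round(inter / union, 4))  # 110 147 0.7483
```

If the hypothesis holds, one thing worth trying (a suggestion to verify, not a confirmed fix) is passing the tensors in index format with an explicit batch dimension, e.g. `torchmetricsMeanIoU(num_classes=6, include_background=False, per_class=True, input_format="index")` updated with `pred.unsqueeze(0)` and `gt.unsqueeze(0)`.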

Expected behavior

The implementation posted in a comment on #2558 gives the correct values for this case:

from scripts.miou import MeanIoU as customMeanIoU
custom_miou = customMeanIoU(num_classes=6, include_background=False, per_class=True)
custom_miou.update(pred, gt)
print("custom_miou", custom_miou.compute())

Its output is:

custom_miou tensor([0.6667, 1.0000, 1.0000, 1.0000, 0.3846])

Environment

  • TorchMetrics version (if built from source, add commit SHA): torchmetrics 1.8.0.dev0 (built from latest source)
  • Python & PyTorch Version (e.g., 1.0): Python 3.10, PyTorch 2.4.0+cu118
  • Any other relevant information such as OS (e.g., Linux): Ubuntu

Metadata


Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)
