From 5a226431f6beee606b1db9a45bab1f19fdd449d9 Mon Sep 17 00:00:00 2001 From: Blazej Dolicki Date: Mon, 23 Jan 2023 14:00:46 +0100 Subject: [PATCH] Hacky fix for Dice score with `average` set to `weighted` or `none` --- src/torchmetrics/classification/dice.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 6bcae6247d8..626d8e0be1a 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -163,14 +163,11 @@ def __init__( self.ignore_index = ignore_index self.top_k = top_k - if average not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {average} is not valid.") - if mdmc_average not in [None, "samplewise", "global"]: raise ValueError(f"The `mdmc_reduce` {mdmc_average} is not valid.") - if average == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `average` as 'macro', you have to provide the number of classes.") + if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): + raise ValueError(f"When you set `average` as '{average}', you have to provide the number of classes.") if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") @@ -180,7 +177,7 @@ def __init__( if mdmc_average != "samplewise" and average != "samples": if average == "micro": zeros_shape = [] - elif average == "macro": + elif average in ["macro", "weighted", "none", None]: zeros_shape = [num_classes] else: raise ValueError(f'Wrong reduce="{average}"')