@@ -99,7 +99,9 @@ def get_stats(
99
99
threshold (Optional[float, List[float]]): Binarization threshold for
100
100
``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
101
101
num_classes (Optional[int]): Number of classes, necessary attribute
102
- only for ``'multiclass'`` mode.
102
+ only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
103
+ If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
104
+ ``255``.
103
105
104
106
Raises:
105
107
ValueError: in case of misconfiguration.
@@ -139,12 +141,16 @@ def get_stats(
139
141
if mode == "multiclass" and num_classes is None :
140
142
raise ValueError ("``num_classes`` attribute should be not ``None`` for 'multiclass' mode." )
141
143
144
+ if ignore_index is not None and 0 <= ignore_index <= num_classes - 1 :
145
+ raise ValueError (
146
+ f"``ignore_index`` should be outside the class values range, but got class values in range "
147
+ f"0..{ num_classes - 1 } and ``ignore_index={ ignore_index } ``. Hint: if you have ``ignore_index = 0``"
148
+ f"consirder subtracting ``1`` from your target and model output to make ``ignore_index = -1``"
149
+ f"and relevant class values started from ``0``."
150
+ )
151
+
142
152
if mode == "multiclass" :
143
- if ignore_index is not None :
144
- ignore = target == ignore_index
145
- output = torch .where (ignore , - 1 , output )
146
- target = torch .where (ignore , - 1 , target )
147
- tp , fp , fn , tn = _get_stats_multiclass (output , target , num_classes )
153
+ tp , fp , fn , tn = _get_stats_multiclass (output , target , num_classes , ignore_index )
148
154
else :
149
155
if threshold is not None :
150
156
output = torch .where (output >= threshold , 1 , 0 )
@@ -159,11 +165,18 @@ def _get_stats_multiclass(
159
165
output : torch .LongTensor ,
160
166
target : torch .LongTensor ,
161
167
num_classes : int ,
168
+ ignore_index : Optional [int ],
162
169
) -> Tuple [torch .LongTensor , torch .LongTensor , torch .LongTensor , torch .LongTensor ]:
163
170
164
171
batch_size , * dims = output .shape
165
172
num_elements = torch .prod (torch .tensor (dims )).long ()
166
173
174
+ if ignore_index is not None :
175
+ ignore = target == ignore_index
176
+ output = torch .where (ignore , - 1 , output )
177
+ target = torch .where (ignore , - 1 , target )
178
+ ignore_per_sample = ignore .view (batch_size , - 1 ).sum (1 )
179
+
167
180
tp_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
168
181
fp_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
169
182
fn_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
@@ -178,6 +191,8 @@ def _get_stats_multiclass(
178
191
fp = torch .histc (output_i .float (), bins = num_classes , min = 0 , max = num_classes - 1 ) - tp
179
192
fn = torch .histc (target_i .float (), bins = num_classes , min = 0 , max = num_classes - 1 ) - tp
180
193
tn = num_elements - tp - fp - fn
194
+ if ignore_index is not None :
195
+ tn = tn - ignore_per_sample [i ]
181
196
tp_count [i ] = tp .long ()
182
197
fp_count [i ] = fp .long ()
183
198
fn_count [i ] = fn .long ()
0 commit comments