Skip to content

Commit 7b22ad7

Browse files
authored
Fix ignore_index for multiclass metric computation (#547)
1 parent 2447352 commit 7b22ad7

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ __pycache__/
33
*.py[cod]
44
*$py.class
55
.idea/
6+
.venv*
67

78
# C extensions
89
*.so

segmentation_models_pytorch/metrics/functional.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def get_stats(
9999
threshold (Optional[float, List[float]]): Binarization threshold for
100100
``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
101101
num_classes (Optional[int]): Number of classes, necessary attribute
102-
only for ``'multiclass'`` mode.
102+
only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
103+
If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
104+
``255``.
103105
104106
Raises:
105107
ValueError: in case of misconfiguration.
@@ -139,12 +141,16 @@ def get_stats(
139141
if mode == "multiclass" and num_classes is None:
140142
raise ValueError("``num_classes`` attribute should be not ``None`` for 'multiclass' mode.")
141143

144+
if ignore_index is not None and 0 <= ignore_index <= num_classes - 1:
145+
raise ValueError(
146+
f"``ignore_index`` should be outside the class values range, but got class values in range "
147+
f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0``"
148+
f"consirder subtracting ``1`` from your target and model output to make ``ignore_index = -1``"
149+
f"and relevant class values started from ``0``."
150+
)
151+
142152
if mode == "multiclass":
143-
if ignore_index is not None:
144-
ignore = target == ignore_index
145-
output = torch.where(ignore, -1, output)
146-
target = torch.where(ignore, -1, target)
147-
tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes)
153+
tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes, ignore_index)
148154
else:
149155
if threshold is not None:
150156
output = torch.where(output >= threshold, 1, 0)
@@ -159,11 +165,18 @@ def _get_stats_multiclass(
159165
output: torch.LongTensor,
160166
target: torch.LongTensor,
161167
num_classes: int,
168+
ignore_index: Optional[int],
162169
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
163170

164171
batch_size, *dims = output.shape
165172
num_elements = torch.prod(torch.tensor(dims)).long()
166173

174+
if ignore_index is not None:
175+
ignore = target == ignore_index
176+
output = torch.where(ignore, -1, output)
177+
target = torch.where(ignore, -1, target)
178+
ignore_per_sample = ignore.view(batch_size, -1).sum(1)
179+
167180
tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
168181
fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
169182
fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
@@ -178,6 +191,8 @@ def _get_stats_multiclass(
178191
fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
179192
fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
180193
tn = num_elements - tp - fp - fn
194+
if ignore_index is not None:
195+
tn = tn - ignore_per_sample[i]
181196
tp_count[i] = tp.long()
182197
fp_count[i] = fp.long()
183198
fn_count[i] = fn.long()

0 commit comments

Comments
 (0)