
Commit e07b1df

KumoLiu authored and wyli committed
Precision issue in get_confusion_matrix (Project-MONAI#7187)
Fixes Project-MONAI#7186

### Description

Remove unnecessary `float()`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
1 parent 3c8fb67 commit e07b1df

File tree

monai/metrics/confusion_matrix.py
monai/metrics/f_beta_score.py
monai/metrics/meaniou.py
monai/metrics/regression.py
monai/metrics/surface_dice.py
monai/metrics/utils.py
tests/test_compute_confusion_matrix.py

7 files changed: +41 -50 lines
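The motivation for dropping the early `float()` casts is the limited integer range of float32. The following standalone snippet is not part of the commit; it is a minimal sketch of the accumulation issue, using the same voxel count as the new test case added below.

    import torch

    # 1 x 1 x 1024 x 1024 x 44 voxels, as in the new TEST_CASE_PRECISION below
    n_voxels = 1 * 1 * 1024 * 1024 * 44
    print(n_voxels)  # 46137344

    # float32 has a 24-bit significand, so not every integer above 2**24
    # (16,777,216) is representable; adding 1.0 to 2**24 in float32 is lost.
    x = torch.tensor(2.0**24, dtype=torch.float32)
    print(x + 1 == x)  # tensor(True)

    # Counting with boolean tensors avoids the problem: summing a bool tensor
    # accumulates in int64, so the count stays exact and is cast to float last.
    mask = torch.ones(10, dtype=torch.bool)
    print(mask.sum().float())  # tensor(10.)

Accumulating per-voxel counts as float32 ones can therefore drift once counts pass 2**24, which is exactly the regime of the 1024 x 1024 x 44 volume used in the new test.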

monai/metrics/confusion_matrix.py

Lines changed: 5 additions & 8 deletions
@@ -153,9 +153,6 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

@@ -165,12 +162,12 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
     # As for classification tasks, S equals to 1.
     y_pred = y_pred.reshape(batch_size, n_class, -1)
     y = y.reshape(batch_size, n_class, -1)
-    tp = ((y_pred + y) == 2).float()
-    tn = ((y_pred + y) == 0).float()
+    tp = (y_pred + y) == 2
+    tn = (y_pred + y) == 0

-    tp = tp.sum(dim=[2])
-    tn = tn.sum(dim=[2])
-    p = y.sum(dim=[2])
+    tp = tp.sum(dim=[2]).float()
+    tn = tn.sum(dim=[2]).float()
+    p = y.sum(dim=[2]).float()
     n = y.shape[-1] - p

     fn = p - tp
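As a rough illustration of what the updated counting does, here is a simplified sketch (not the MONAI function itself; `confusion_counts` and the tiny tensors are made up for the example, and the inputs are assumed to be already-binarized 0/1 tensors of shape [B, C, S]):

    import torch

    def confusion_counts(y_pred, y):
        # Comparisons yield bool tensors; sums accumulate exactly (int64) before the final float cast.
        tp = ((y_pred + y) == 2).sum(dim=2).float()  # predicted 1, truth 1
        tn = ((y_pred + y) == 0).sum(dim=2).float()  # predicted 0, truth 0
        p = y.sum(dim=2).float()                     # ground-truth positives
        n = y.shape[-1] - p                          # ground-truth negatives
        fn = p - tp
        fp = n - tn
        return tp, fp, tn, fn

    y_pred = torch.tensor([[[1, 0, 1, 0]]])
    y = torch.tensor([[[1, 1, 0, 0]]])
    print(confusion_counts(y_pred, y))  # tp, fp, tn, fn all tensor([[1.]])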

monai/metrics/f_beta_score.py

Lines changed: 5 additions & 8 deletions
@@ -63,9 +63,6 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

@@ -75,12 +72,12 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
     # As for classification tasks, S equals to 1.
     y_pred = y_pred.view(batch_size, n_class, -1)
     y = y.view(batch_size, n_class, -1)
-    tp = ((y_pred + y) == 2).float()
-    tn = ((y_pred + y) == 0).float()
+    tp = (y_pred + y) == 2
+    tn = (y_pred + y) == 0

-    tp = tp.sum(dim=[2])
-    tn = tn.sum(dim=[2])
-    p = y.sum(dim=[2])
+    tp = tp.sum(dim=[2]).float()
+    tn = tn.sum(dim=[2]).float()
+    p = y.sum(dim=[2]).float()
     n = y.shape[-1] - p

     fn = p - tp
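For context (not shown in the diff), these per-class counts feed an F-beta computation; a small hedged helper, assuming the standard textbook formula rather than MONAI's exact downstream code:

    import torch

    def f_beta(tp, fp, fn, beta: float = 1.0):
        # Standard definition: F_beta = (1 + beta^2) * tp / ((1 + beta^2) * tp + beta^2 * fn + fp)
        b2 = beta * beta
        return (1 + b2) * tp / ((1 + b2) * tp + b2 * fn + fp)

    print(f_beta(torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)))  # tensor(0.5000) -- F1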

monai/metrics/meaniou.py

Lines changed: 0 additions & 3 deletions
@@ -130,9 +130,6 @@ def compute_iou(
     if not include_background:
         y_pred, y = ignore_background(y_pred=y_pred, y=y)

-    y = y.float()
-    y_pred = y_pred.float()
-
     if y.shape != y_pred.shape:
         raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

monai/metrics/regression.py

Lines changed: 0 additions & 12 deletions
@@ -111,9 +111,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         return compute_mean_error_metrics(y_pred, y, func=self.sq_func)

@@ -143,9 +140,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.abs_func = torch.abs

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         return compute_mean_error_metrics(y_pred, y, func=self.abs_func)

@@ -176,9 +170,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        y_pred = y_pred.float()
-        y = y.float()
-
         mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
         return torch.sqrt(mse_out)

@@ -218,9 +209,6 @@ def __init__(
         self.sq_func = partial(torch.pow, exponent=2.0)

     def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any:
-        y_pred = y_pred.float()
-        y = y.float()
-
         mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
         return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out)
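One side effect of removing the casts in these `_compute_metric` methods (an inference, not stated in the commit) is that higher-precision inputs are no longer downcast: `.float()` converts to float32, which cannot resolve differences that float64 can. A minimal sketch, reusing the same `partial(torch.pow, exponent=2.0)` squaring helper shown in the diff context:

    from functools import partial

    import torch

    sq_func = partial(torch.pow, exponent=2.0)

    y_pred = torch.tensor([1.0 + 1e-9], dtype=torch.float64)
    y = torch.tensor([1.0], dtype=torch.float64)

    # Old behaviour: casting to float32 first rounds 1.0 + 1e-9 back to 1.0.
    print(sq_func(y_pred.float() - y.float()))  # tensor([0.])
    # Keeping the input dtype preserves the difference.
    print(sq_func(y_pred - y))                  # roughly tensor([1.0000e-18], dtype=torch.float64)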

monai/metrics/surface_dice.py

Lines changed: 0 additions & 3 deletions
@@ -228,9 +228,6 @@ def compute_surface_dice(
             f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
         )

-    y = y.float()
-    y_pred = y_pred.float()
-
     batch_size, n_class = y_pred.shape[:2]

     if n_class != len(class_thresholds):

monai/metrics/utils.py

Lines changed: 16 additions & 16 deletions
@@ -95,37 +95,37 @@ def do_metric_reduction(
     # some elements might be Nan (if ground truth y was missing (zeros))
     # we need to account for it
     nans = torch.isnan(f)
-    not_nans = (~nans).float()
+    not_nans = ~nans

-    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
+    t_zero = torch.zeros(1, device=f.device, dtype=torch.float)
     reduction = look_up_option(reduction, MetricReduction)
     if reduction == MetricReduction.NONE:
-        return f, not_nans
+        return f, not_nans.float()

     f[nans] = 0
     if reduction == MetricReduction.MEAN:
         # 2 steps, first, mean by channel (accounting for nans), then by batch
-        not_nans = not_nans.sum(dim=1)
-        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
+        not_nans = not_nans.sum(dim=1).float()
+        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average

-        not_nans = (not_nans > 0).float().sum(dim=0)
-        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
+        not_nans = (not_nans > 0).sum(dim=0).float()
+        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average

     elif reduction == MetricReduction.SUM:
-        not_nans = not_nans.sum(dim=[0, 1])
+        not_nans = not_nans.sum(dim=[0, 1]).float()
         f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
     elif reduction == MetricReduction.MEAN_BATCH:
-        not_nans = not_nans.sum(dim=0)
-        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
+        not_nans = not_nans.sum(dim=0).float()
+        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average
     elif reduction == MetricReduction.SUM_BATCH:
-        not_nans = not_nans.sum(dim=0)
-        f = f.sum(dim=0)  # the batch sum
+        not_nans = not_nans.sum(dim=0).float()
+        f = f.sum(dim=0).float()  # the batch sum
     elif reduction == MetricReduction.MEAN_CHANNEL:
-        not_nans = not_nans.sum(dim=1)
-        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
+        not_nans = not_nans.sum(dim=1).float()
+        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average
     elif reduction == MetricReduction.SUM_CHANNEL:
-        not_nans = not_nans.sum(dim=1)
-        f = f.sum(dim=1)  # the channel sum
+        not_nans = not_nans.sum(dim=1).float()
+        f = f.sum(dim=1).float()  # the channel sum
     elif reduction != MetricReduction.NONE:
         raise ValueError(
             f"Unsupported reduction: {reduction}, available options are "

tests/test_compute_confusion_matrix.py

Lines changed: 15 additions & 0 deletions
@@ -210,6 +210,14 @@

 TEST_CASES_CLF = [data_clf.copy(), result_clf]

+TEST_CASE_PRECISION = [
+    {
+        "y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
+        "y": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
+    },
+    torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]),
+]
+

 class TestConfusionMatrix(unittest.TestCase):
     @parameterized.expand([TEST_CASE_CONFUSION_MATRIX])
@@ -274,6 +282,13 @@ def test_clf_with_nan(self, input_data, expected_value):
         expected_value = compute_confusion_matrix_metric("tpr", expected_value)
         assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)

+    @parameterized.expand([TEST_CASE_PRECISION])
+    def test_precision(self, input_data, expected_value):
+        # include or ignore background
+        result = get_confusion_matrix(**input_data)
+        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+        np.testing.assert_equal(result.device, input_data["y_pred"].device)
+

 if __name__ == "__main__":
     unittest.main()
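The expected value in `TEST_CASE_PRECISION` follows from the input shape: with all-zero `y_pred` and `y`, every voxel is a true negative, so the third entry of the expected row is the full voxel count. A quick sanity check (illustrative only, not part of the test file):

    import torch

    shape = [1, 1, 1024, 1024, 44]
    assert 1 * 1 * 1024 * 1024 * 44 == 46137344  # every voxel is a true negative

    # The bool-based counting yields the exact value even for more than 2**24 voxels.
    z = torch.zeros(shape)
    tn = ((z + z) == 0).sum().float()
    print(tn)  # tensor(46137344.)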
