Skip to content

Commit a944081

Browse files
junbao-zhouvfdev-5
andauthored
fix an error message of confusion_matrix.IoU (#2613)
* fix error messages in confusion_matrix.py to make them match condition check * update test_confusion_matrix.py to match the new error message Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 03577ac commit a944081

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

ignite/metrics/confusion_matrix.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambd
239239

240240
if ignore_index is not None:
241241
if not (isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes):
242-
raise ValueError(f"ignore_index should be non-negative integer, but given {ignore_index}")
242+
raise ValueError(
243+
f"ignore_index should be integer and in the range of [0, {cm.num_classes}), but given {ignore_index}"
244+
)
243245

244246
# Increase floating point precision and pass to CPU
245247
cm = cm.to(torch.double)
@@ -393,7 +395,9 @@ def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) ->
393395

394396
if ignore_index is not None:
395397
if not (isinstance(ignore_index, numbers.Integral) and 0 <= ignore_index < cm.num_classes):
396-
raise ValueError(f"ignore_index should be non-negative integer, but given {ignore_index}")
398+
raise ValueError(
399+
f"ignore_index should be integer and in the range of [0, {cm.num_classes}), but given {ignore_index}"
400+
)
397401

398402
# Increase floating point precision and pass to CPU
399403
cm = cm.to(torch.double)

tests/ignite/metrics/test_confusion_matrix.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,16 +189,16 @@ def test_iou_wrong_input():
189189
IoU(None)
190190

191191
cm = ConfusionMatrix(num_classes=10)
192-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
192+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given -1"):
193193
IoU(cm, ignore_index=-1)
194194

195-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
195+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given a"):
196196
IoU(cm, ignore_index="a")
197197

198-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
198+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given 10"):
199199
IoU(cm, ignore_index=10)
200200

201-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
201+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given 11"):
202202
IoU(cm, ignore_index=11)
203203

204204

@@ -403,16 +403,16 @@ def test_dice_coefficient_wrong_input():
403403
DiceCoefficient(None)
404404

405405
cm = ConfusionMatrix(num_classes=10)
406-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
406+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given -1"):
407407
DiceCoefficient(cm, ignore_index=-1)
408408

409-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
409+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given a"):
410410
DiceCoefficient(cm, ignore_index="a")
411411

412-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
412+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given 10"):
413413
DiceCoefficient(cm, ignore_index=10)
414414

415-
with pytest.raises(ValueError, match="ignore_index should be non-negative integer"):
415+
with pytest.raises(ValueError, match=r"ignore_index should be integer and in the range of \[0, 10\), but given 11"):
416416
DiceCoefficient(cm, ignore_index=11)
417417

418418

0 commit comments

Comments
 (0)