Skip to content

Commit a778e58

Browse files
committed
Add test cases
1 parent 60c7b36 commit a778e58

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

tests/test_dice_loss.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
},
3535
0.416657,
3636
],
37+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
38+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
39+
{
40+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
41+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
42+
},
43+
0.0,
44+
],
45+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
46+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
47+
{
48+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
49+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
50+
},
51+
0.307773,
52+
],
3753
[ # shape: (2, 2, 3), (2, 1, 3)
3854
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
3955
{

tests/test_generalized_dice_loss.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
},
3535
0.416597,
3636
],
37+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
38+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
39+
{
40+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
41+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
42+
},
43+
0.0,
44+
],
45+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
46+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
47+
{
48+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
49+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
50+
},
51+
0.307748,
52+
],
3753
[ # shape: (2, 2, 3), (2, 1, 3)
3854
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0},
3955
{

tests/test_tversky_loss.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@
3434
},
3535
0.416657,
3636
],
37+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
38+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
39+
{
40+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
41+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
42+
},
43+
0.0,
44+
],
45+
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
46+
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
47+
{
48+
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
49+
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
50+
},
51+
0.307773,
52+
],
3753
[ # shape: (2, 2, 3), (2, 1, 3)
3854
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
3955
{

0 commit comments

Comments
 (0)