Skip to content

Commit 2800a76

Browse files
feat: add clDice loss (#6763)
Fixes #5938 ### Description This PR aims to add the `SoftclDiceLoss` and the `SoftDiceclDiceLoss` from [clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation](https://openaccess.thecvf.com/content/CVPR2021/papers/Shit_clDice_-_A_Novel_Topology-Preserving_Loss_Function_for_Tubular_Structure_CVPR_2021_paper.pdf) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
1 parent 28c9083 commit 2800a76

File tree

3 files changed

+241
-0
lines changed

3 files changed

+241
-0
lines changed

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from __future__ import annotations
1313

14+
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
1415
from .contrastive import ContrastiveLoss
1516
from .deform import BendingEnergyLoss
1617
from .dice import (

monai/losses/cldice.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import torch
15+
import torch.nn.functional as F
16+
from torch.nn.modules.loss import _Loss
17+
18+
19+
def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
20+
"""
21+
Perform soft erosion on the input image
22+
23+
Args:
24+
img: the shape should be BCH(WD)
25+
26+
Adapted from:
27+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6
28+
"""
29+
if len(img.shape) == 4:
30+
p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)))
31+
p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)))
32+
return torch.min(p1, p2) # type: ignore
33+
elif len(img.shape) == 5:
34+
p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)))
35+
p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)))
36+
p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)))
37+
return torch.min(torch.min(p1, p2), p3) # type: ignore
38+
39+
40+
def soft_dilate(img: torch.Tensor) -> torch.Tensor: # type: ignore
41+
"""
42+
Perform soft dilation on the input image
43+
44+
Args:
45+
img: the shape should be BCH(WD)
46+
47+
Adapted from:
48+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18
49+
"""
50+
if len(img.shape) == 4:
51+
return F.max_pool2d(img, (3, 3), (1, 1), (1, 1)) # type: ignore
52+
elif len(img.shape) == 5:
53+
return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1)) # type: ignore
54+
55+
56+
def soft_open(img: torch.Tensor) -> torch.Tensor:
57+
"""
58+
Wrapper function to perform soft opening on the input image
59+
60+
Args:
61+
img: the shape should be BCH(WD)
62+
63+
Adapted from:
64+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25
65+
"""
66+
eroded_image = soft_erode(img)
67+
dilated_image = soft_dilate(eroded_image)
68+
return dilated_image
69+
70+
71+
def soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:
72+
"""
73+
Perform soft skeletonization on the input image
74+
75+
Adapted from:
76+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29
77+
78+
Args:
79+
img: the shape should be BCH(WD)
80+
iter_: number of iterations for skeletonization
81+
82+
Returns:
83+
skeletonized image
84+
"""
85+
img1 = soft_open(img)
86+
skel = F.relu(img - img1)
87+
for _ in range(iter_):
88+
img = soft_erode(img)
89+
img1 = soft_open(img)
90+
delta = F.relu(img - img1)
91+
skel = skel + F.relu(delta - skel * delta)
92+
return skel
93+
94+
95+
def soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
96+
"""
97+
Function to compute soft dice loss
98+
99+
Adapted from:
100+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22
101+
102+
Args:
103+
y_true: the shape should be BCH(WD)
104+
y_pred: the shape should be BCH(WD)
105+
106+
Returns:
107+
dice loss
108+
"""
109+
intersection = torch.sum((y_true * y_pred)[:, 1:, ...])
110+
coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)
111+
soft_dice: torch.Tensor = 1.0 - coeff
112+
return soft_dice
113+
114+
115+
class SoftclDiceLoss(_Loss):
116+
"""
117+
Compute the Soft clDice loss defined in:
118+
119+
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
120+
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)
121+
122+
Adapted from:
123+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7
124+
"""
125+
126+
def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:
127+
"""
128+
Args:
129+
iter_: Number of iterations for skeletonization
130+
smooth: Smoothing parameter
131+
"""
132+
super().__init__()
133+
self.iter = iter_
134+
self.smooth = smooth
135+
136+
def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
137+
skel_pred = soft_skel(y_pred, self.iter)
138+
skel_true = soft_skel(y_true, self.iter)
139+
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
140+
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
141+
)
142+
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
143+
torch.sum(skel_true[:, 1:, ...]) + self.smooth
144+
)
145+
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
146+
return cl_dice
147+
148+
149+
class SoftDiceclDiceLoss(_Loss):
150+
"""
151+
Compute the Soft clDice loss defined in:
152+
153+
Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function
154+
for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)
155+
156+
Adapted from:
157+
https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38
158+
"""
159+
160+
def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:
161+
"""
162+
Args:
163+
iter_: Number of iterations for skeletonization
164+
smooth: Smoothing parameter
165+
alpha: Weighing factor for cldice
166+
"""
167+
super().__init__()
168+
self.iter = iter_
169+
self.smooth = smooth
170+
self.alpha = alpha
171+
172+
def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
173+
dice = soft_dice(y_true, y_pred, self.smooth)
174+
skel_pred = soft_skel(y_pred, self.iter)
175+
skel_true = soft_skel(y_true, self.iter)
176+
tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (
177+
torch.sum(skel_pred[:, 1:, ...]) + self.smooth
178+
)
179+
tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (
180+
torch.sum(skel_true[:, 1:, ...]) + self.smooth
181+
)
182+
cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)
183+
total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice
184+
return total_loss

tests/test_cldice_loss.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# You may obtain a copy of the License at
2+
# http://www.apache.org/licenses/LICENSE-2.0
3+
# Unless required by applicable law or agreed to in writing, software
4+
# distributed under the License is distributed on an "AS IS" BASIS,
5+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6+
# See the License for the specific language governing permissions and
7+
# limitations under the License.
8+
9+
from __future__ import annotations
10+
11+
import unittest
12+
13+
import numpy as np
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss
18+
19+
TEST_CASES = [
20+
[ # shape: (1, 4), (1, 4)
21+
{"y_pred": torch.ones((100, 3, 256, 256)), "y_true": torch.ones((100, 3, 256, 256))},
22+
0.0,
23+
],
24+
[ # shape: (1, 5), (1, 5)
25+
{"y_pred": torch.ones((100, 3, 256, 256, 5)), "y_true": torch.ones((100, 3, 256, 256, 5))},
26+
0.0,
27+
],
28+
]
29+
30+
31+
class TestclDiceLoss(unittest.TestCase):
32+
@parameterized.expand(TEST_CASES)
33+
def test_result(self, y_pred_data, expected_val):
34+
loss = SoftclDiceLoss()
35+
loss_dice = SoftDiceclDiceLoss()
36+
result = loss(**y_pred_data)
37+
result_dice = loss_dice(**y_pred_data)
38+
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
39+
np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
40+
41+
def test_with_cuda(self):
42+
loss = SoftclDiceLoss()
43+
loss_dice = SoftDiceclDiceLoss()
44+
i = torch.ones((100, 3, 256, 256))
45+
j = torch.ones((100, 3, 256, 256))
46+
if torch.cuda.is_available():
47+
i = i.cuda()
48+
j = j.cuda()
49+
output = loss(i, j)
50+
output_dice = loss_dice(i, j)
51+
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
52+
np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)