Initial commit -- Adding calibration loss specific to segmentation #7819
Merged
Changes from all commits (54 commits)
8fbec82 Initial commit -- Adding calibration loss specific to segmentation (Bala93)
23b897b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b2ec62b Update __init__.py (Bala93)
93ee114 Update segcalib.py (Bala93)
42e732b Update segcalib.py (Bala93)
187053d Update segcalib.py (Bala93)
1d27ec5 Update segcalib.py (Bala93)
d499134 Update segcalib.py (Bala93)
1e3f755 Update segcalib.py (Bala93)
9dedfba Update segcalib.py (Bala93)
59959ce Update monai/losses/segcalib.py (Bala93)
cf1d044 Update monai/losses/segcalib.py (Bala93)
0926851 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
5317706 Update segcalib.py (Bala93)
3155433 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7c121a0 Add specific to gaussian for both 2d and 3d (Bala93)
24efd85 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
0067953 Merge branch 'Project-MONAI:dev' into model-calibration (Bala93)
dccde47 Add mean loss and resolve formatting (Bala93)
44e8065 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
57686d7 Merge branch 'dev' into model-calibration (Bala93)
5cd9a33 Update segcalib.py (Bala93)
b547c4e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
42a0215 Update segcalib.py (Bala93)
7e36ca1 Update segcalib.py (Bala93)
6dbd53d Update segcalib.py (Bala93)
354056c Update segcalib.py (Bala93)
7eb911f Update segcalib.py (Bala93)
0b1209b Update segcalib.py (Bala93)
035c92e Update segcalib.py (Bala93)
c1de5f1 Rename segcalib.py to nacl_loss.py (Bala93)
91dd1b9 Update __init__.py (Bala93)
9702c02 Update test_nacl_loss.py (Bala93)
4462379 Update nacl_loss.py (Bala93)
c4f8283 Update test_nacl_loss.py (Bala93)
bc6b995 Update test_nacl_loss.py (Bala93)
51e15fe Added missing parameters in doc (Bala93)
3a00aec Formatting check with monai (Bala93)
818b42b Update test_nacl_loss.py (Bala93)
6647708 Added mypy fixes (Bala93)
7e579dd DCO Remediation Commit for bala93 <balamuralim.1993@gmail.com> (Bala93)
4f8abf1 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b72e478 Update docs/source/losses.rst (Bala93)
747681d * Include test cases covering more cases (Bala93)
3b15554 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
877139c Update monai/losses/nacl_loss.py (Bala93)
4679456 Update monai/losses/nacl_loss.py (Bala93)
7c5217e * Add docstring with better explanations (Bala93)
d33f435 * Maintain the dimension consistency. (Bala93)
7deb2cc Update nacl_loss.py (Bala93)
91ce50b Update nacl_loss.py (Bala93)
7f87e0c Merge branch 'model-calibration' of https://github.com/Bala93/MONAI i… (Bala93)
0e880a8 Modify docstring (Bala93)
db9daeb Merge branch 'dev' into model-calibration (KumoLiu)
monai/losses/nacl_loss.py (new file)
@@ -0,0 +1,139 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.networks.layers import GaussianFilter, MeanFilter


class NACLLoss(_Loss):
    """
    Neighbor-Aware Calibration Loss (NACL) is primarily developed for training calibrated models in image segmentation.
    NACL computes the standard cross-entropy loss with a linear penalty that enforces the logit distributions
    to match a soft class proportion of the surrounding pixels.

    Murugesan, Balamurali, et al.
    "Trust your neighbours: Penalty-based constraints for model calibration."
    International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
    https://arxiv.org/abs/2303.06268

    Murugesan, Balamurali, et al.
    "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
    https://arxiv.org/abs/2401.14487
    """

    def __init__(
        self,
        classes: int,
        dim: int,
        kernel_size: int = 3,
        kernel_ops: str = "mean",
        distance_type: str = "l1",
        alpha: float = 0.1,
        sigma: float = 1.0,
    ) -> None:
        """
        Args:
            classes: number of classes
            dim: dimension of data (supports 2d and 3d)
            kernel_size: size of the spatial kernel
            kernel_ops: type of smoothing kernel, either "mean" or "gaussian"
            distance_type: l1/l2 distance between the output of the spatial kernel and the predicted logits
            alpha: weighting between cross entropy and the logit constraint
            sigma: sigma of the gaussian kernel
        """

        super().__init__()

        if kernel_ops not in ["mean", "gaussian"]:
            raise ValueError("Kernel ops must be either mean or gaussian")

        if dim not in [2, 3]:
            raise ValueError(f"Support 2d and 3d, got dim={dim}.")

        if distance_type not in ["l1", "l2"]:
            raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")

        self.nc = classes
        self.dim = dim
        self.cross_entropy = nn.CrossEntropyLoss()
        self.distance_type = distance_type
        self.alpha = alpha
        self.ks = kernel_size
        self.svls_layer: Any

        if kernel_ops == "mean":
            self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
            self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
        if kernel_ops == "gaussian":
            self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)

    def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Converts the mask to a one-hot representation and smooths it with the selected spatial filter.

        Args:
            mask: the shape should be BH[WD].

        Returns:
            torch.Tensor: the shape would be BNH[WD], N being number of classes.
        """
        rmask: torch.Tensor

        if self.dim == 2:
            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()

            rmask = self.svls_layer(oh_labels)

        if self.dim == 3:
            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
            rmask = self.svls_layer(oh_labels)

        return rmask

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Computes the standard cross-entropy loss and constrains it with the neighbor-aware logit penalty.

        Args:
            inputs: the shape should be BNH[WD], where N is the number of classes.
            targets: the shape should be BH[WD].

        Returns:
            torch.Tensor: value of the loss.

        Example:
            >>> import torch
            >>> from monai.losses import NACLLoss
            >>> B, N, H, W = 8, 3, 64, 64
            >>> input = torch.rand(B, N, H, W)
            >>> target = torch.randint(0, N, (B, H, W))
            >>> criterion = NACLLoss(classes=N, dim=2)
            >>> loss = criterion(input, target)
        """

        loss_ce = self.cross_entropy(inputs, targets)

        utargets = self.get_constr_target(targets)

        if self.distance_type == "l1":
            loss_conf = utargets.sub(inputs).abs_().mean()
        elif self.distance_type == "l2":
            loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()

        loss: torch.Tensor = loss_ce + self.alpha * loss_conf

        return loss
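
In the L1 case, the value returned above amounts to cross_entropy(inputs, targets) + alpha * mean(|svls_layer(one_hot(targets)) - inputs|), i.e. the penalty is taken between the raw logits and the spatially smoothed one-hot labels. As a quick orientation, here is a minimal training-step sketch (not part of this PR) showing how NACLLoss could sit in a MONAI segmentation pipeline; the UNet configuration, tensor shapes, learning rate, and variable names are illustrative assumptions, not values from the diff:

import torch
from monai.losses import NACLLoss
from monai.networks.nets import UNet

B, N, H, W = 4, 3, 64, 64  # assumed batch size, class count, spatial size

net = UNet(spatial_dims=2, in_channels=1, out_channels=N, channels=(8, 16, 32), strides=(2, 2))
criterion = NACLLoss(classes=N, dim=2, kernel_ops="gaussian", sigma=1.0, alpha=0.1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

images = torch.rand(B, 1, H, W)          # dummy image batch
labels = torch.randint(0, N, (B, H, W))  # dummy integer label batch, shape BHW

optimizer.zero_grad()
logits = net(images)                     # BNHW raw logits; NACLLoss applies CE and the penalty to logits directly
loss = criterion(logits, labels)
loss.backward()
optimizer.step()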
test_nacl_loss.py (new file)
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import NACLLoss

inputs = torch.tensor(
    [
        [
            [
                [0.1498, 0.1158, 0.3996, 0.3730],
                [0.2155, 0.1585, 0.8541, 0.8579],
                [0.6640, 0.2424, 0.0774, 0.0324],
                [0.0580, 0.2180, 0.3447, 0.8722],
            ],
            [
                [0.3908, 0.9366, 0.1779, 0.1003],
                [0.9630, 0.6118, 0.4405, 0.7916],
                [0.5782, 0.9515, 0.4088, 0.3946],
                [0.7860, 0.3910, 0.0324, 0.9568],
            ],
            [
                [0.0759, 0.0238, 0.5570, 0.1691],
                [0.2703, 0.7722, 0.1611, 0.6431],
                [0.8051, 0.6596, 0.4121, 0.1125],
                [0.5283, 0.6746, 0.5528, 0.7913],
            ],
        ]
    ]
)
targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]])

TEST_CASES = [
    [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442],
    [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433],
    [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469],
    [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269],
    [{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790],
    [
        {"classes": 3, "dim": 3, "kernel_ops": "gaussian"},
        {
            "inputs": torch.tensor(
                [
                    [
                        [
                            [
                                [0.5977, 0.2767, 0.0591, 0.1675],
                                [0.4835, 0.3778, 0.8406, 0.3065],
                                [0.6047, 0.2860, 0.9742, 0.2013],
                                [0.9128, 0.8368, 0.6711, 0.4384],
                            ],
                            [
                                [0.9797, 0.1863, 0.5584, 0.6652],
                                [0.2272, 0.2004, 0.7914, 0.4224],
                                [0.5097, 0.8818, 0.2581, 0.3495],
                                [0.1054, 0.5483, 0.3732, 0.3587],
                            ],
                            [
                                [0.3060, 0.7066, 0.7922, 0.4689],
                                [0.1733, 0.8902, 0.6704, 0.2037],
                                [0.8656, 0.5561, 0.2701, 0.0092],
                                [0.1866, 0.7714, 0.6424, 0.9791],
                            ],
                            [
                                [0.5067, 0.3829, 0.6156, 0.8985],
                                [0.5192, 0.8347, 0.2098, 0.2260],
                                [0.8887, 0.3944, 0.6400, 0.5345],
                                [0.1207, 0.3763, 0.5282, 0.7741],
                            ],
                        ],
                        [
                            [
                                [0.8499, 0.4759, 0.1964, 0.5701],
                                [0.3190, 0.1238, 0.2368, 0.9517],
                                [0.0797, 0.6185, 0.0135, 0.8672],
                                [0.4116, 0.1683, 0.1355, 0.0545],
                            ],
                            [
                                [0.7533, 0.2658, 0.5955, 0.4498],
                                [0.9500, 0.2317, 0.2825, 0.9763],
                                [0.1493, 0.1558, 0.3743, 0.8723],
                                [0.1723, 0.7980, 0.8816, 0.0133],
                            ],
                            [
                                [0.8426, 0.2666, 0.2077, 0.3161],
                                [0.1725, 0.8414, 0.1515, 0.2825],
                                [0.4882, 0.5159, 0.4120, 0.1585],
                                [0.2551, 0.9073, 0.7691, 0.9898],
                            ],
                            [
                                [0.4633, 0.8717, 0.8537, 0.2899],
                                [0.3693, 0.7953, 0.1183, 0.4596],
                                [0.0087, 0.7925, 0.0989, 0.8385],
                                [0.8261, 0.6920, 0.7069, 0.4464],
                            ],
                        ],
                        [
                            [
                                [0.0110, 0.1608, 0.4814, 0.6317],
                                [0.0194, 0.9669, 0.3259, 0.0028],
                                [0.5674, 0.8286, 0.0306, 0.5309],
                                [0.3973, 0.8183, 0.0238, 0.1934],
                            ],
                            [
                                [0.8947, 0.6629, 0.9439, 0.8905],
                                [0.0072, 0.1697, 0.4634, 0.0201],
                                [0.7184, 0.2424, 0.0820, 0.7504],
                                [0.3937, 0.1424, 0.4463, 0.5779],
                            ],
                            [
                                [0.4123, 0.6227, 0.0523, 0.8826],
                                [0.0051, 0.0353, 0.3662, 0.7697],
                                [0.4867, 0.8986, 0.2510, 0.5316],
                                [0.1856, 0.2634, 0.9140, 0.9725],
                            ],
                            [
                                [0.2041, 0.4248, 0.2371, 0.7256],
                                [0.2168, 0.5380, 0.4538, 0.7007],
                                [0.9013, 0.2623, 0.0739, 0.2998],
                                [0.1366, 0.5590, 0.2952, 0.4592],
                            ],
                        ],
                    ]
                ]
            ),
            "targets": torch.tensor(
                [
                    [
                        [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]],
                        [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]],
                        [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]],
                        [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]],
                    ]
                ]
            ),
        },
        1.15035,
    ],
]


class TestNACLLoss(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data, expected_val):
        loss = NACLLoss(**input_param)
        result = loss(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)


if __name__ == "__main__":
    unittest.main()
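
The expected values in TEST_CASES can be regenerated by evaluating the loss directly on the fixed tensors. A small hypothetical helper (not part of the test file, and assuming the imports above) sketches this:

def regenerate_expected(input_param: dict, input_data: dict, digits: int = 4) -> float:
    # Evaluate NACLLoss on the fixed tensors and print the rounded reference value.
    value = NACLLoss(**input_param)(**input_data).item()
    print(round(value, digits))
    return value

# e.g. regenerate_expected(*TEST_CASES[0][:2]) should print approximately 1.1442.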