Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial commit -- Adding calibration loss specific to segmentation #7819

Merged
merged 54 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8fbec82
Initial commit -- Adding calibration loss specific to segmentation
Bala93 Jun 2, 2024
23b897b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2024
b2ec62b
Update __init__.py
Bala93 Jun 2, 2024
93ee114
Update segcalib.py
Bala93 Jun 2, 2024
42e732b
Update segcalib.py
Bala93 Jun 2, 2024
187053d
Update segcalib.py
Bala93 Jun 2, 2024
1d27ec5
Update segcalib.py
Bala93 Jun 2, 2024
d499134
Update segcalib.py
Bala93 Jun 3, 2024
1e3f755
Update segcalib.py
Bala93 Jun 3, 2024
9dedfba
Update segcalib.py
Bala93 Jun 4, 2024
59959ce
Update monai/losses/segcalib.py
Bala93 Jun 14, 2024
cf1d044
Update monai/losses/segcalib.py
Bala93 Jun 14, 2024
0926851
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
5317706
Update segcalib.py
Bala93 Jun 15, 2024
3155433
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2024
7c121a0
Add specific to gaussian for both 2d and 3d
Bala93 Aug 3, 2024
24efd85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
0067953
Merge branch 'Project-MONAI:dev' into model-calibration
Bala93 Aug 3, 2024
dccde47
Add mean loss and resolve formatting
Bala93 Aug 3, 2024
44e8065
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
57686d7
Merge branch 'dev' into model-calibration
Bala93 Aug 3, 2024
5cd9a33
Update segcalib.py
Bala93 Aug 3, 2024
b547c4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 3, 2024
42a0215
Update segcalib.py
Bala93 Aug 3, 2024
7e36ca1
Update segcalib.py
Bala93 Aug 3, 2024
6dbd53d
Update segcalib.py
Bala93 Aug 3, 2024
354056c
Update segcalib.py
Bala93 Aug 4, 2024
7eb911f
Update segcalib.py
Bala93 Aug 4, 2024
0b1209b
Update segcalib.py
Bala93 Aug 4, 2024
035c92e
Update segcalib.py
Bala93 Aug 4, 2024
c1de5f1
Rename segcalib.py to nacl_loss.py
Bala93 Aug 5, 2024
91dd1b9
Update __init__.py
Bala93 Aug 5, 2024
9702c02
Update test_nacl_loss.py
Bala93 Aug 5, 2024
4462379
Update nacl_loss.py
Bala93 Aug 5, 2024
c4f8283
Update test_nacl_loss.py
Bala93 Aug 5, 2024
bc6b995
Update test_nacl_loss.py
Bala93 Aug 5, 2024
51e15fe
Added missing parameters in doc
Bala93 Aug 5, 2024
3a00aec
Formatting check with monai
Bala93 Aug 5, 2024
818b42b
Update test_nacl_loss.py
Bala93 Aug 5, 2024
6647708
Added mypy fixes
Bala93 Aug 5, 2024
7e579dd
DCO Remediation Commit for bala93 <balamuralim.1993@gmail.com>
Bala93 Aug 5, 2024
4f8abf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
b72e478
Update docs/source/losses.rst
Bala93 Aug 6, 2024
747681d
* Include test cases covering more cases
Bala93 Aug 7, 2024
3b15554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
877139c
Update monai/losses/nacl_loss.py
Bala93 Aug 7, 2024
4679456
Update monai/losses/nacl_loss.py
Bala93 Aug 7, 2024
7c5217e
* Add docstring with better explanations
Bala93 Aug 7, 2024
d33f435
* Maintain the dimension consistency.
Bala93 Aug 7, 2024
7deb2cc
Update nacl_loss.py
Bala93 Aug 7, 2024
91ce50b
Update nacl_loss.py
Bala93 Aug 7, 2024
7f87e0c
Merge branch 'model-calibration' of https://github.com/Bala93/MONAI i…
Bala93 Aug 7, 2024
0e880a8
Modify docstring
Bala93 Aug 7, 2024
db9daeb
Merge branch 'dev' into model-calibration
KumoLiu Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ Segmentation Losses
.. autoclass:: SoftDiceclDiceLoss
:members:

`NACLLoss`
~~~~~~~~~~~~~~~~~~~~
Bala93 marked this conversation as resolved.
Show resolved Hide resolved
.. autoclass:: NACLLoss
:members:

Registration Losses
-------------------

Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .nacl_loss import NACLLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
Expand Down
279 changes: 279 additions & 0 deletions monai/losses/nacl_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.utils import pytorch_after


def get_mean_kernel_2d(ksize: int = 3) -> torch.Tensor:
    """
    Build a 2D box (mean) kernel.

    Args:
        ksize: edge length of the square kernel.

    Returns:
        A ``(ksize, ksize)`` tensor in which every entry equals ``1 / ksize**2``.
    """
    n_elements = ksize**2
    box = torch.ones([ksize, ksize])
    return box / n_elements


def get_mean_kernel_3d(ksize: int = 3) -> torch.Tensor:
    """
    Build a 3D box (mean) kernel.

    Args:
        ksize: edge length of the cubic kernel.

    Returns:
        A ``(ksize, ksize, ksize)`` tensor in which every entry equals ``1 / ksize**3``.
    """
    n_elements = ksize**3
    box = torch.ones([ksize, ksize, ksize])
    return box / n_elements


def get_gaussian_kernel_2d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor:
    """
    Build a 2D isotropic Gaussian kernel normalised to sum to one.

    Args:
        ksize: edge length of the square kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        A ``(ksize, ksize)`` float tensor centred on index ``(ksize - 1) / 2``
        along each axis.
    """
    center = (ksize - 1) / 2.0
    var = sigma**2.0

    # Squared offset of each index from the kernel centre along a single axis.
    sq_offset = (torch.arange(ksize).float() - center) ** 2.0
    # Broadcasting the per-axis offsets gives the squared distance at every (row, col).
    sq_dist = sq_offset.unsqueeze(0) + sq_offset.unsqueeze(1)

    # The small epsilons keep both divisions finite when sigma is 0.
    amplitude = 1.0 / (2.0 * math.pi * var + 1e-16)
    kernel: torch.Tensor = amplitude * torch.exp(-sq_dist / (2 * var + 1e-16))
    return kernel / torch.sum(kernel)


def get_gaussian_kernel_3d(ksize: int = 3, sigma: float = 1.0) -> torch.Tensor:
    """
    Build a 3D isotropic Gaussian kernel normalised to sum to one.

    Args:
        ksize: edge length of the cubic kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        A ``(ksize, ksize, ksize)`` float tensor centred on index
        ``(ksize - 1) / 2`` along each axis.
    """
    center = (ksize - 1) / 2.0
    var = sigma**2.0

    # Squared offset of each index from the kernel centre along a single axis.
    sq_offset = (torch.arange(ksize).float() - center) ** 2.0
    # Broadcast the per-axis offsets over the three axes to get the squared
    # distance from the centre at every voxel.
    sq_dist = sq_offset.view(1, 1, -1) + sq_offset.view(1, -1, 1) + sq_offset.view(-1, 1, 1)

    # The constant prefactor cancels in the normalisation below; the epsilons
    # keep both divisions finite when sigma is 0.
    amplitude = 1.0 / (2.0 * math.pi * var + 1e-16)
    kernel: torch.Tensor = amplitude * torch.exp(-sq_dist / (2 * var + 1e-16))
    return kernel / torch.sum(kernel)


class GaussianFilter(torch.nn.Module):
    """
    Depth-wise Gaussian neighbourhood filter for 2D or 3D multi-channel maps.

    A normalised Gaussian kernel is re-weighted so that its neighbours (every
    entry except the centre) sum to one and the centre itself has weight one.
    The kernel is applied independently to every channel through a frozen
    grouped convolution with replicate padding, and ``forward`` divides the
    result by the total kernel weight.
    """

    def __init__(self, dim: int = 3, ksize: int = 3, sigma: float = 1.0, channels: int = 0) -> None:
        """
        Args:
            dim: number of spatial dimensions, 2 or 3.
            ksize: edge length of the kernel.
            sigma: standard deviation of the Gaussian.
            channels: number of channels to filter; each channel gets its own
                copy of the kernel (``groups=channels``). Must be positive.

        Raises:
            ValueError: if ``dim`` is not 2 or 3, or ``channels`` is not positive.
        """
        super().__init__()

        if dim not in (2, 3):
            raise ValueError(f"GaussianFilter supports dim 2 or 3, got {dim}.")
        if channels < 1:
            raise ValueError(f"GaussianFilter requires a positive number of channels, got {channels}.")

        self.svls_kernel: torch.Tensor
        self.svls_layer: Any

        if dim == 2:
            gkernel = get_gaussian_kernel_2d(ksize=ksize, sigma=sigma)
            conv_type: Any = torch.nn.Conv2d
        else:
            gkernel = get_gaussian_kernel_3d(ksize=ksize, sigma=sigma)
            conv_type = torch.nn.Conv3d

        # Index of the kernel centre for any kernel size, not only ksize == 3.
        center = (ksize // 2,) * dim
        # Total weight of all neighbours; the epsilon avoids a division by zero
        # when the kernel consists of the centre only (ksize == 1).
        neighbors_sum = (1 - gkernel[center]) + 1e-16
        # After the division below the centre weight is 1 and the neighbours sum to 1.
        gkernel[center] = neighbors_sum
        self.svls_kernel = gkernel / neighbors_sum

        # One identical kernel per channel, shaped (channels, 1, *kernel_shape).
        weight = self.svls_kernel.view(1, 1, *self.svls_kernel.shape).repeat(channels, 1, *([1] * dim))

        self.svls_layer = conv_type(
            in_channels=channels,
            out_channels=channels,
            kernel_size=ksize,
            groups=channels,
            bias=False,
            padding=ksize // 2,
            padding_mode="replicate",
        )
        # The kernel is fixed: overwrite the random initialisation and freeze it.
        self.svls_layer.weight.data = weight
        self.svls_layer.weight.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Filter ``x`` channel-wise and normalise by the total kernel weight.

        Args:
            x: tensor accepted by the underlying grouped convolution, i.e. with
                ``channels`` channels and ``dim`` spatial dimensions.

        Returns:
            The filtered tensor divided by the sum of the kernel weights.
        """
        svls_normalized: torch.Tensor = self.svls_layer(x) / self.svls_kernel.sum()
        return svls_normalized


class MeanFilter(torch.nn.Module):
    """
    Depth-wise mean (box) filter for 2D or 3D multi-channel maps.

    The box kernel is applied independently to every channel through a frozen
    grouped convolution with replicate padding, and ``forward`` divides the
    result by the total kernel weight.
    """

    def __init__(self, dim: int = 3, ksize: int = 3, channels: int = 0) -> None:
        """
        Args:
            dim: number of spatial dimensions; a layer is only built for 2 or 3.
            ksize: edge length of the kernel.
            channels: number of channels to filter; each channel gets its own
                copy of the kernel (``groups=channels``).
        """
        super().__init__()

        self.svls_kernel: torch.Tensor
        self.svls_layer: Any

        if dim == 2:
            box = get_mean_kernel_2d(ksize=ksize)
            conv_type: Any = torch.nn.Conv2d
        elif dim == 3:
            box = get_mean_kernel_3d(ksize=ksize)
            conv_type = torch.nn.Conv3d
        else:
            # Other dimensionalities leave the module without kernel and layer.
            return

        self.svls_kernel = box
        # One identical kernel per channel, shaped (channels, 1, *kernel_shape).
        per_channel = box.reshape(1, 1, *box.shape).repeat(channels, 1, *([1] * dim))

        conv = conv_type(
            in_channels=channels,
            out_channels=channels,
            kernel_size=ksize,
            groups=channels,
            bias=False,
            padding=int(ksize / 2),
            padding_mode="replicate",
        )
        # The kernel is fixed: overwrite the random initialisation and freeze it.
        conv.weight.data = per_channel
        conv.weight.requires_grad = False
        self.svls_layer = conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Filter ``x`` channel-wise and normalise by the total kernel weight.

        Args:
            x: tensor accepted by the underlying grouped convolution.

        Returns:
            The filtered tensor divided by the sum of the kernel weights.
        """
        filtered = self.svls_layer(x)
        total_weight = self.svls_kernel.sum()
        svls_normalized: torch.Tensor = filtered / total_weight
        return svls_normalized


class NACLLoss(_Loss):
    """
    Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
    NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
    to match a soft class proportion of surrounding pixel.

    Murugesan, Balamurali, et al.
    "Trust your neighbours: Penalty-based constraints for model calibration."
    International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
    https://arxiv.org/abs/2303.06268

    Murugesan, Balamurali, et al.
    "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
    https://arxiv.org/abs/2401.14487
    """

    def __init__(
        self,
        classes: int,
        dim: int,
        kernel_size: int = 3,
        kernel_ops: str = "mean",
        distance_type: str = "l1",
        alpha: float = 0.1,
        sigma: float = 1.0,
    ) -> None:
        """
        Args:
            classes: number of classes
            dim: dimension of data (supports 2d and 3d)
            kernel_size: size of the spatial kernel
            kernel_ops: type of the spatial kernel, either "mean" or "gaussian"
            distance_type: l1/l2 distance between spatial kernel and predicted logits
            alpha: weightage between cross entropy and logit constraint
            sigma: sigma of gaussian (only used when ``kernel_ops`` is "gaussian")

        Raises:
            ValueError: if ``kernel_ops`` is not "mean"/"gaussian", ``dim`` is not 2/3,
                or ``distance_type`` is not "l1"/"l2".
        """

        super().__init__()

        if kernel_ops not in ["mean", "gaussian"]:
            raise ValueError("Kernel ops must be either mean or gaussian")

        if dim not in [2, 3]:
            raise ValueError("Supports 2d and 3d")

        if distance_type not in ["l1", "l2"]:
            raise ValueError("Distance type must be either L1 or L2")

        self.nc = classes
        self.dim = dim
        self.cross_entropy = nn.CrossEntropyLoss()
        self.distance_type = distance_type
        self.alpha = alpha
        self.ks = kernel_size
        self.svls_layer: Any

        if kernel_ops == "mean":
            self.svls_layer = MeanFilter(dim=dim, ksize=kernel_size, channels=classes)
        else:
            self.svls_layer = GaussianFilter(dim=dim, ksize=kernel_size, sigma=sigma, channels=classes)

        # Not read anywhere in this class; kept so the attribute stays available.
        self.old_pt_ver = not pytorch_after(1, 10)

    def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Build the soft, neighbourhood-aware target for the logit constraint.

        Args:
            mask: class-index map without a channel dimension, i.e. shaped
                ``(B, H, W)`` for 2D or ``(B, H, W, D)`` for 3D data. It is cast
                to int64 before one-hot encoding.

        Returns:
            A float tensor with the class axis at position 1, obtained by
            one-hot encoding ``mask`` and smoothing every class channel with
            the configured spatial filter.
        """
        oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc)
        # one_hot appends the class axis last; move it to position 1 (channel-first).
        channel_first = (0, self.dim + 1, *range(1, self.dim + 1))
        rmask: torch.Tensor = self.svls_layer(oh_labels.permute(*channel_first).float())
        return rmask

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy plus the weighted neighbour-aware logit penalty.

        Args:
            inputs: predicted logits with the class axis at position 1, i.e.
                ``(B, C, H, W)`` for 2D or ``(B, C, H, W, D)`` for 3D data.
            targets: class-index map without a channel dimension, passed both to
                ``nn.CrossEntropyLoss`` and to :meth:`get_constr_target`.

        Returns:
            Scalar tensor ``cross_entropy + alpha * mean(distance(soft_target, inputs))``.

        Raises:
            ValueError: if ``distance_type`` was changed to an unsupported value
                after construction.
        """
        loss_ce = self.cross_entropy(inputs, targets)

        utargets = self.get_constr_target(targets)
        diff = utargets - inputs

        if self.distance_type == "l1":
            loss_conf = diff.abs().mean()
        elif self.distance_type == "l2":
            loss_conf = diff.pow(2).mean()
        else:
            raise ValueError("Distance type must be either L1 or L2")

        loss: torch.Tensor = loss_ce + self.alpha * loss_conf

        return loss
Loading
Loading