
add softmax version to focal_loss #6544

Merged (13 commits, May 27, 2023)

monai/losses/focal_loss.py: 160 changes (106 additions, 54 deletions)
@@ -28,10 +28,10 @@ class FocalLoss(_Loss):
FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
high confidence correct predictions.

Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in:
Reimplementation of the Focal Loss described in:

- "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
- "AnatomyNet: Deep learning for fast and fully automated wholevolume segmentation of head and neck anatomy",
- ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
- "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
Zhu et al., Medical Physics 2018

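A small numerical sketch of the down-weighting described above (an editor's illustration, not part of the file): with the modulating factor (1 - p_t)**gamma, confidently correct predictions contribute far less than under plain cross entropy.

>>> import math
>>> gamma = 2.0
>>> for p_t in (0.6, 0.9, 0.99):
...     ce = -math.log(p_t)
...     fl = (1 - p_t) ** gamma * ce
...     print(round(ce, 4), round(fl, 4))
0.5108 0.0817
0.1054 0.0011
0.0101 0.0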
Example:
@@ -70,19 +70,23 @@ def __init__(
include_background: bool = True,
to_onehot_y: bool = False,
gamma: float = 2.0,
alpha: float | None = None,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
) -> None:
"""
Args:
include_background: if False, channel index 0 (background category) is excluded from the calculation.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
gamma: value of the exponent gamma in the definition of the Focal loss.
include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
If False, `alpha` is invalid when using softmax.
to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
The value should be in [0, 1]. Defaults to None.
weight: weights to apply to the voxels of each class. If None no weights are applied.
This corresponds to the weights `\alpha` in [1].
The input can be a single value (same weight for all classes), a sequence of values (the length
of the sequence should be the same as the number of classes, if not ``include_background``, the
number should not include class 0).
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
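To make the length requirement on `weight` concrete, a brief sketch (editor's example; the class counts and values are illustrative):

>>> from monai.losses import FocalLoss
>>> # 3 prediction channels = background + 2 foreground classes;
>>> # with include_background=False only the 2 foreground classes are weighted
>>> criterion = FocalLoss(include_background=False, weight=[1.0, 2.0])
>>> # a single number applies the same weight to every (remaining) class
>>> criterion_uniform = FocalLoss(include_background=False, weight=0.5)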
@@ -91,6 +95,9 @@ def __init__(
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

use_softmax: whether to use softmax to transform the original logits into probabilities.
If True, softmax is used. If False, sigmoid is used. Defaults to False.

Example:
>>> import torch
>>> from monai.losses import FocalLoss
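An aside on the new `use_softmax` flag (editor's sketch, separate from the docstring example above): sigmoid treats each channel as an independent binary problem, whereas softmax normalizes the channels into mutually exclusive class probabilities.

>>> import torch
>>> logits = torch.tensor([[2.0, -1.0, 0.5]])
>>> torch.sigmoid(logits)            # per-channel probabilities, need not sum to 1
tensor([[0.8808, 0.2689, 0.6225]])
>>> torch.softmax(logits, dim=1)     # probabilities normalized across channels
tensor([[0.7856, 0.0391, 0.1753]])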
@@ -103,14 +110,16 @@ def __init__(
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.gamma = gamma
self.weight: Sequence[float] | float | int | torch.Tensor | None = weight
self.alpha = alpha
self.weight = weight
self.use_softmax = use_softmax

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
a sigmoid/softmax in the forward function.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

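A minimal end-to-end sketch of the call path added in this PR (editor's example; tensor sizes and hyper-parameter values are illustrative):

>>> import torch
>>> from monai.losses import FocalLoss
>>> logits = torch.randn(2, 3, 16, 16)             # B x N x H x W
>>> labels = torch.randint(0, 3, (2, 1, 16, 16))   # B x 1 x H x W class indices
>>> criterion = FocalLoss(to_onehot_y=True, gamma=2.0, alpha=0.25, use_softmax=True)
>>> loss = criterion(logits, labels)
>>> loss.shape                                     # default "mean" reduction gives a scalar
torch.Size([])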
Raises:
@@ -141,63 +150,106 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

i = input
t = target

# Change the shape of input and target to B x N x num_voxels.
b, n = t.shape[:2]
i = i.reshape(b, n, -1)
t = t.reshape(b, n, -1)

# computing binary cross entropy with logits
# see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
max_val = (-i).clamp(min=0)
ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
loss: Optional[torch.Tensor] = None
input = input.float()
target = target.float()
if self.use_softmax:
if not self.include_background and self.alpha is not None:
self.alpha = None
warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
else:
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

if self.weight is not None:
# make sure the lengths of weights are equal to the number of classes
class_weight: Optional[torch.Tensor] = None
num_of_classes = target.shape[1]
if isinstance(self.weight, (float, int)):
class_weight = torch.as_tensor([self.weight] * i.size(1))
class_weight = torch.as_tensor([self.weight] * num_of_classes)
else:
class_weight = torch.as_tensor(self.weight)
if class_weight.size(0) != i.size(1):
if class_weight.shape[0] != num_of_classes:
raise ValueError(
"the length of the weight sequence should be the same as the number of classes. "
+ "If `include_background=False`, the number should not include class 0."
"""the length of the `weight` sequence should be the same as the number of classes.
If `include_background=False`, the weight should not include
the background category class 0."""
)
if class_weight.min() < 0:
raise ValueError("the value/values of weights should be no less than 0.")
class_weight = class_weight.to(i)
# Convert the weight to a map in which each voxel
# has the weight associated with the ground-truth label
# associated with this voxel in target.
at = class_weight[None, :, None] # N => 1,N,1
at = at.expand((t.size(0), -1, t.size(2))) # 1,N,1 => B,N,H*W
# Multiply the log proba by their weights.
ce = ce * at

# Compute the loss mini-batch.
# (1-p_t)^gamma * log(p_t) with reduced chance of overflow
p = F.logsigmoid(-i * (t * 2.0 - 1.0))
flat_loss: torch.Tensor = (p * self.gamma).exp() * ce

# Previously there was a mean over the last dimension, which did not
# return a compatible BCE loss. To maintain backwards compatible
# behavior we have a flag that performs this extra step, disable or
# parameterize if necessary. (Or justify why the mean should be there)
average_spatial_dims = True
raise ValueError("the value/values of the `weight` should be no less than 0.")
# apply class_weight to loss
class_weight = class_weight.to(loss)
broadcast_dims = [-1] + [1] * len(target.shape[2:])
class_weight = class_weight.view(broadcast_dims)
loss = class_weight * loss

if self.reduction == LossReduction.SUM.value:
# Previously there was a mean over the last dimension, which did not
# return a compatible BCE loss. To maintain backwards compatible
# behavior we have a flag that performs this extra step, disable or
# parameterize if necessary. (Or justify why the mean should be there)
average_spatial_dims = True
if average_spatial_dims:
flat_loss = flat_loss.mean(dim=-1)
loss = flat_loss.sum()
loss = loss.mean(dim=list(range(2, len(target.shape))))
loss = loss.sum()
elif self.reduction == LossReduction.MEAN.value:
if average_spatial_dims:
flat_loss = flat_loss.mean(dim=-1)
loss = flat_loss.mean()
loss = loss.mean()
elif self.reduction == LossReduction.NONE.value:
spacetime_dims = input.shape[2:]
loss = flat_loss.reshape([b, n] + list(spacetime_dims))
pass
else:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
return loss


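# Editor's note (illustrative sketch, not part of this diff): the per-class
# `weight` above is reshaped so it broadcasts over the batch and spatial axes
# of the per-voxel loss, and the "sum" reduction first averages the spatial
# dimensions to stay backwards compatible. For a B x N x H x W loss tensor:
#
#     class_weight = torch.as_tensor([1.0, 2.0, 0.5])    # N = 3 classes
#     broadcast_dims = [-1] + [1] * 2                     # two spatial dims (H, W)
#     class_weight = class_weight.view(broadcast_dims)    # shape (3, 1, 1)
#     weighted = class_weight * loss                      # broadcasts to (B, 3, H, W)
#     summed = weighted.mean(dim=[2, 3]).sum()            # the "sum" reduction path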
def softmax_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
s_j is the unnormalized score for class j.
"""
input_ls = input.log_softmax(1)
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target

if alpha is not None:
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_fac = alpha_fac.view(broadcast_dims)
loss = alpha_fac * loss

return loss


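# Editor's note (illustrative check, not part of this diff): with gamma=0 and
# alpha=None, softmax_focal_loss reduces to ordinary cross entropy once the
# per-class terms are summed over the channel dimension:
#
#     import torch
#     import torch.nn.functional as F
#     logits = torch.randn(2, 3, 4, 4)
#     labels = torch.randint(0, 3, (2, 4, 4))
#     one_hot = F.one_hot(labels, 3).permute(0, 3, 1, 2).float()
#     fl = softmax_focal_loss(logits, one_hot, gamma=0.0).sum(dim=1)
#     ce = F.cross_entropy(logits, labels, reduction="none")
#     assert torch.allclose(fl, ce, atol=1e-6)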
def sigmoid_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
"""
# computing binary cross entropy with logits
# equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
# see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
max_val = (-input).clamp(min=0)
loss: torch.Tensor = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()

# sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
# 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
# 1-p if t==1; p if t==0 <=>
# pfac, that is, the term (1 - pt)
invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow
# (pfac.log() * gamma).exp() <=>
# pfac.log().exp() ^ gamma <=>
# pfac ^ gamma
loss = (invprobs * gamma).exp() * loss

if alpha is not None:
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha + (1 - target) * (1 - alpha)
loss = alpha_factor * loss

return loss
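As a quick sanity check on the sigmoid path (an editor's sketch, not output from the PR's test suite): with gamma=0 and no alpha, sigmoid_focal_loss should match binary cross entropy with logits elementwise, which is exactly what the comments above claim.

>>> import torch
>>> import torch.nn.functional as F
>>> from monai.losses.focal_loss import sigmoid_focal_loss
>>> x = torch.randn(2, 3, 8, 8)
>>> t = torch.randint(0, 2, (2, 3, 8, 8)).float()
>>> fl = sigmoid_focal_loss(x, t, gamma=0.0)
>>> bce = F.binary_cross_entropy_with_logits(x, t, reduction="none")
>>> torch.allclose(fl, bce, atol=1e-5)
True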