Skip to content

Add alpha parameter to DiceFocalLoss #7841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 11 commits on Jun 29, 2024
Merged
13 changes: 10 additions & 3 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss):
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.

``gamma`` and ``lambda_focal`` are only used for the focal loss.
``include_background``, ``weight`` and ``reduction`` are used for both losses
``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,
and other parameters are only used for dice loss.

"""
Expand All @@ -837,6 +837,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_focal: float = 1.0,
alpha: float | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -871,7 +872,8 @@ def __init__(
Defaults to 1.0.
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
Defaults to 1.0.

alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in
[0, 1]. Defaults to None.
"""
super().__init__()
weight = focal_weight if focal_weight is not None else weight
Expand All @@ -890,7 +892,12 @@ def __init__(
weight=weight,
)
self.focal = FocalLoss(
include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
include_background=include_background,
to_onehot_y=False,
gamma=gamma,
weight=weight,
alpha=alpha,
reduction=reduction,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
Expand Down
29 changes: 29 additions & 0 deletions tests/test_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,35 @@ def test_script(self):
test_input = torch.ones(2, 1, 8, 8)
test_script_save(loss, test_input, test_input)

@parameterized.expand(
    [
        # case name encodes: reduction mode, weight variant, lambda_focal, alpha
        (f"{red}_{w_name}_0.5_0.25", red, w_val, 0.5, 0.25)
        for red in ("sum", "mean", "none")
        for w_name, w_val in (
            ("None", None),
            ("weight", torch.tensor([1.0, 1.0, 2.0])),
            ("weight_tuple", (3, 2.0, 1)),
        )
    ]
)
def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
    """Check that DiceFocalLoss with ``alpha`` matches DiceLoss + lambda_focal * FocalLoss."""
    shape = [3, 3, 5, 5]
    target = torch.randint(low=0, high=2, size=shape)
    logits = torch.randn(shape)

    # parameters shared by all three loss constructors
    shared = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight}

    combined = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **shared)
    dice_only = DiceLoss(**shared)
    focal_only = FocalLoss(gamma=1.0, alpha=alpha, **shared)

    actual = combined(logits, target)
    # the combined loss must equal the weighted sum of its two components
    expected = dice_only(logits, target) + lambda_focal * focal_only(logits, target)

    np.testing.assert_allclose(actual, expected, err_msg=f"Failed on case: {name}")


if __name__ == "__main__":
unittest.main()
Loading