
Commit 2f62b81

Authored by kephale, pre-commit-ci[bot], ericspod, and KumoLiu
Add alpha parameter to DiceFocalLoss (#7841)
Fixes #7682.

### Description

This PR introduces the `alpha` parameter from `FocalLoss` into `DiceFocalLoss`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Kyle Harrington <czi@kyleharrington.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 06cbd70 commit 2f62b81
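
For illustration, a minimal usage sketch of the new argument. The shapes, hyperparameter values, and the `sigmoid=True` choice below are illustrative assumptions, not taken from this PR:

```python
import torch
from monai.losses import DiceFocalLoss

# After this change, `alpha` accepted by DiceFocalLoss is forwarded to its
# internal FocalLoss term; values are expected to lie in [0, 1].
loss_fn = DiceFocalLoss(sigmoid=True, gamma=2.0, alpha=0.25, lambda_focal=0.5)

# Illustrative logits and one-hot-style targets (batch=2, channels=3, 16x16).
pred = torch.randn(2, 3, 16, 16)
target = torch.randint(0, 2, (2, 3, 16, 16)).float()

print(loss_fn(pred, target))  # scalar tensor with the default "mean" reduction
```

Previously, `alpha` could only be set by constructing `FocalLoss` directly rather than through the combined loss.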

2 files changed: +39 −3 lines

monai/losses/dice.py

Lines changed: 10 additions & 3 deletions
@@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss):
     The details of Focal Loss is shown in ``monai.losses.FocalLoss``.

     ``gamma`` and ``lambda_focal`` are only used for the focal loss.
-    ``include_background``, ``weight`` and ``reduction`` are used for both losses
+    ``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,
     and other parameters are only used for dice loss.

     """
@@ -837,6 +837,7 @@ def __init__(
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         lambda_dice: float = 1.0,
         lambda_focal: float = 1.0,
+        alpha: float | None = None,
     ) -> None:
         """
         Args:
@@ -871,7 +872,8 @@ def __init__(
                 Defaults to 1.0.
             lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
                 Defaults to 1.0.
-
+            alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in
+                [0, 1]. Defaults to None.
         """
         super().__init__()
         weight = focal_weight if focal_weight is not None else weight
@@ -890,7 +892,12 @@ def __init__(
             weight=weight,
         )
         self.focal = FocalLoss(
-            include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
+            include_background=include_background,
+            to_onehot_y=False,
+            gamma=gamma,
+            weight=weight,
+            alpha=alpha,
+            reduction=reduction,
         )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
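
For context (not part of the diff), the `alpha` forwarded to `FocalLoss` here is the balancing factor of the alpha-balanced focal loss from Lin et al., "Focal Loss for Dense Object Detection" (2017). In the binary case it is usually written as below; how MONAI applies it to multi-channel inputs is defined in ``monai.losses.FocalLoss``:

```latex
% Alpha-balanced focal loss (binary case), after Lin et al. (2017).
\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t),
\qquad
\alpha_t =
\begin{cases}
\alpha & \text{for the foreground class} \\
1 - \alpha & \text{for the background class}
\end{cases}
```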

tests/test_dice_focal_loss.py

Lines changed: 29 additions & 0 deletions
@@ -91,6 +91,35 @@ def test_script(self):
         test_input = torch.ones(2, 1, 8, 8)
         test_script_save(loss, test_input, test_input)

+    @parameterized.expand(
+        [
+            ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
+            ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+            ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
+            ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
+            ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+            ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
+            ("none_None_0.5_0.25", "none", None, 0.5, 0.25),
+            ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+            ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
+        ]
+    )
+    def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
+        size = [3, 3, 5, 5]
+        label = torch.randint(low=0, high=2, size=size)
+        pred = torch.randn(size)
+
+        common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight}
+
+        dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
+        dice = DiceLoss(**common_params)
+        focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)
+
+        result = dice_focal(pred, label)
+        expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
+
+        np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}")
+

 if __name__ == "__main__":
     unittest.main()
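
The assertion above relies on the weighted sum that `DiceFocalLoss` computes; with `lambda_dice` left at its default of 1.0, that term drops out of `expected_val`. Roughly:

```latex
% Combined loss checked by the test; lambda_dice defaults to 1.0.
\mathcal{L}_{\text{DiceFocal}}(p, y)
  = \lambda_{\text{dice}} \cdot \mathcal{L}_{\text{Dice}}(p, y)
  + \lambda_{\text{focal}} \cdot \mathcal{L}_{\text{Focal},\alpha}(p, y)
```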
