Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def __init__(
weight: torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
label_smoothing: float = 0.0,
) -> None:
"""
Args:
Expand Down Expand Up @@ -728,7 +729,7 @@ def __init__(
batch=batch,
weight=dice_weight,
)
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
Expand All @@ -737,6 +738,7 @@ def __init__(
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce
self.old_pt_ver = not pytorch_after(1, 10)
self.label_smoothing = label_smoothing

def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down