Skip to content

Commit

Permalink
[Bug fixed]Fix dice_loss errors (open-mmlab#417)
Browse files Browse the repository at this point in the history
* fix training bugs

* fix unitest error

* fix error in num_classes==2 case

* delete comments
  • Loading branch information
谢昕辰 authored Mar 29, 2021
1 parent d474cfd commit 71be1c2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
31 changes: 17 additions & 14 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def dice_loss(pred,
smooth=1,
exponent=2,
class_weight=None,
ignore_index=-1):
ignore_index=255):
assert pred.shape[0] == target.shape[0]
total_loss = 0
num_classes = pred.shape[1]
Expand All @@ -36,9 +36,9 @@ def dice_loss(pred,
@weighted_loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
assert pred.shape[0] == target.shape[0]
pred = pred.contiguous().view(pred.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1)
valid_mask = valid_mask.contiguous().view(valid_mask.shape[0], -1)
pred = pred.reshape(pred.shape[0], -1)
target = target.reshape(target.shape[0], -1)
valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
Expand Down Expand Up @@ -70,27 +70,27 @@ class DiceLoss(nn.Module):
"""

def __init__(self,
loss_type='multi_class',
smooth=1,
exponent=2,
reduction='mean',
class_weight=None,
loss_weight=1.0,
ignore_index=255):
ignore_index=255,
**kwards):
super(DiceLoss, self).__init__()
assert loss_type in ['multi_class', 'binary']
if loss_type == 'multi_class':
self.cls_criterion = dice_loss
else:
self.cls_criterion = binary_dice_loss
self.smooth = smooth
self.exponent = exponent
self.reduction = reduction
self.class_weight = class_weight
self.loss_weight = loss_weight
self.ignore_index = ignore_index

def forward(self, pred, target, avg_factor=None, reduction_override=None):
def forward(self,
pred,
target,
avg_factor=None,
reduction_override=None,
**kwards):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
Expand All @@ -100,10 +100,13 @@ def forward(self, pred, target, avg_factor=None, reduction_override=None):
class_weight = None

pred = F.softmax(pred, dim=1)
one_hot_target = F.one_hot(torch.clamp_min(target.long(), 0))
num_classes = pred.shape[1]
one_hot_target = F.one_hot(
torch.clamp(target.long(), 0, num_classes - 1),
num_classes=num_classes)
valid_mask = (target != self.ignore_index).long()

loss = self.loss_weight * self.cls_criterion(
loss = self.loss_weight * dice_loss(
pred,
one_hot_target,
valid_mask=valid_mask,
Expand Down
15 changes: 2 additions & 13 deletions tests/test_models/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,9 @@ def test_lovasz_loss():
def test_dice_lose():
from mmseg.models import build_loss

# loss_type should be 'binary' or 'multi_class'
with pytest.raises(AssertionError):
loss_cfg = dict(
type='DiceLoss',
loss_type='Binary',
reduction='none',
loss_weight=1.0)
build_loss(loss_cfg)

# test dice loss with loss_type = 'multi_class'
loss_cfg = dict(
type='DiceLoss',
loss_type='multi_class',
reduction='none',
class_weight=[1.0, 2.0, 3.0],
loss_weight=1.0,
Expand All @@ -232,13 +222,12 @@ def test_dice_lose():
# test dice loss with loss_type = 'binary'
loss_cfg = dict(
type='DiceLoss',
loss_type='binary',
smooth=2,
exponent=3,
reduction='sum',
loss_weight=1.0,
ignore_index=0)
dice_loss = build_loss(loss_cfg)
logits = torch.rand(16, 4, 4)
labels = (torch.rand(16, 4, 4)).long()
logits = torch.rand(8, 2, 4, 4)
labels = (torch.rand(8, 4, 4) * 2).long()
dice_loss(logits, labels)

0 comments on commit 71be1c2

Please sign in to comment.