Skip to content

[Feature] Support multiple losses during training #818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Sep 24, 2021
18 changes: 18 additions & 0 deletions docs/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,21 @@ model=dict(
```

`class_weight` will be passed into `CrossEntropyLoss` as `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For loss calculation, we support training with multiple losses simultaneously. Here is an example config for training `unet` on the `DRIVE` dataset, whose loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`:

```python
# Inherit the base UNet/DRIVE config and override only the loss settings.
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    # Each head sums its listed losses, scaled by their `loss_weight`
    # (CrossEntropyLoss : DiceLoss = 1 : 3).
    decode_head=dict(
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_ce',
                loss_weight=1.0),
            dict(
                type='DiceLoss',
                loss_name='loss_dice',
                loss_weight=3.0),
        ]),
    auxiliary_head=dict(
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_ce',
                loss_weight=1.0),
            dict(
                type='DiceLoss',
                loss_name='loss_dice',
                loss_weight=3.0),
        ]),
)
```

In this way, `loss_weight` sets the weight of the corresponding loss, and `loss_name` sets the name under which that loss appears in the training log.

Note: If you want a loss item to be included in the backward graph, its name must start with the prefix `loss_`.
19 changes: 19 additions & 0 deletions docs_zh-CN/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,22 @@ model=dict(
```

`class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) 。

## 同时使用多种损失函数 (Multiple Losses)

对于训练时损失函数的计算,我们目前支持多个损失函数同时使用。 以 `unet` 使用 `DRIVE` 数据集训练为例,
使用 `CrossEntropyLoss` 和 `DiceLoss` 的 `1:3` 的加权和作为损失函数。配置文件写为:

```python
# Inherit the base UNet/DRIVE config and override only the loss settings.
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
    # Each head sums its listed losses, scaled by their `loss_weight`
    # (CrossEntropyLoss : DiceLoss = 1 : 3).
    decode_head=dict(
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_ce',
                loss_weight=1.0),
            dict(
                type='DiceLoss',
                loss_name='loss_dice',
                loss_weight=3.0),
        ]),
    auxiliary_head=dict(
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                loss_name='loss_ce',
                loss_weight=1.0),
            dict(
                type='DiceLoss',
                loss_name='loss_dice',
                loss_weight=3.0),
        ]),
)
```

通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
14 changes: 8 additions & 6 deletions mmseg/core/seg/sampler/ohem_pixel_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ def sample(self, seg_logit, seg_label):
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
losses = self.context.loss_decode(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
losses = 0.0
for loss_module in self.context.loss_decode:
losses += loss_module(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
Expand Down
43 changes: 35 additions & 8 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Expand Down Expand Up @@ -73,9 +80,20 @@ def __init__(self,
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.loss_decode = build_loss(loss_decode)

self.ignore_index = ignore_index
self.align_corners = align_corners
self.loss_decode = nn.ModuleList()

if isinstance(loss_decode, dict):
self.loss_decode.append(build_loss(loss_decode))
elif isinstance(loss_decode, (list, tuple)):
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')

if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
Expand Down Expand Up @@ -224,10 +242,19 @@ def losses(self, seg_logit, seg_label):
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
loss['loss_seg'] = self.loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
for loss_decode in self.loss_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)

loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss
5 changes: 3 additions & 2 deletions mmseg/models/decode_heads/point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
def losses(self, point_logits, point_label):
    """Compute the point-wise segmentation losses and accuracy.

    Every module in ``self.loss_decode`` is evaluated on the same
    logits/labels. Each result is stored under the key
    ``'point' + loss_module.loss_name``. When several loss modules share
    the same ``loss_name``, their values are summed instead of the later
    one overwriting the earlier one, matching how
    ``BaseDecodeHead.losses`` combines identically named losses.

    Args:
        point_logits: Predicted logits for the sampled points, passed
            unchanged to every loss module and to ``accuracy``.
        point_label: Ground-truth labels for the sampled points.

    Returns:
        dict: One entry per distinct loss name (prefixed with ``'point'``)
            plus ``'acc_point'`` holding the result of ``accuracy``.
    """
    loss = dict()
    for loss_module in self.loss_decode:
        loss_key = 'point' + loss_module.loss_name
        loss_value = loss_module(
            point_logits, point_label, ignore_index=self.ignore_index)
        if loss_key in loss:
            # Out-of-place add: never mutate a value a loss module returned.
            loss[loss_key] = loss[loss_key] + loss_value
        else:
            loss[loss_key] = loss_value
    loss['acc_point'] = accuracy(point_logits, point_label)
    return loss

Expand Down
21 changes: 20 additions & 1 deletion mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ class CrossEntropyLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
"""

def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
loss_weight=1.0,
loss_name='loss_ce'):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
Expand All @@ -172,6 +176,7 @@ def __init__(self,
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name

def forward(self,
cls_score,
Expand All @@ -197,3 +202,17 @@ def forward(self,
avg_factor=avg_factor,
**kwargs)
return loss_cls

@property
def loss_name(self):
    """Name of this loss item.

    Every loss module is expected to provide this property. Loss items
    that share a name are combined by simple summation, and the name must
    start with `loss_` for the item to be included in the backward graph.

    Returns:
        str: The name of this loss item.
    """
    return self._loss_name
19 changes: 19 additions & 0 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class DiceLoss(nn.Module):
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_dice'.
"""

def __init__(self,
Expand All @@ -77,6 +80,7 @@ def __init__(self,
class_weight=None,
loss_weight=1.0,
ignore_index=255,
loss_name='loss_dice',
**kwards):
super(DiceLoss, self).__init__()
self.smooth = smooth
Expand All @@ -85,6 +89,7 @@ def __init__(self,
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index
self._loss_name = loss_name

def forward(self,
pred,
Expand Down Expand Up @@ -118,3 +123,17 @@ def forward(self,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss

@property
def loss_name(self):
    """Name of this loss item.

    Every loss module is expected to provide this property. Loss items
    that share a name are combined by simple summation, and the name must
    start with `loss_` for the item to be included in the backward graph.

    Returns:
        str: The name of this loss item.
    """
    return self._loss_name
21 changes: 20 additions & 1 deletion mmseg/models/losses/lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ class LovaszLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_lovasz'.
"""

def __init__(self,
Expand All @@ -252,7 +255,8 @@ def __init__(self,
per_image=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
loss_weight=1.0,
loss_name='loss_lovasz'):
super(LovaszLoss, self).__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'."
Expand All @@ -271,6 +275,7 @@ def __init__(self,
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name

def forward(self,
cls_score,
Expand Down Expand Up @@ -302,3 +307,17 @@ def forward(self,
avg_factor=avg_factor,
**kwargs)
return loss_cls

@property
def loss_name(self):
    """Name of this loss item.

    Every loss module is expected to provide this property. Loss items
    that share a name are combined by simple summation, and the name must
    start with `loss_` for the item to be included in the backward graph.

    Returns:
        str: The name of this loss item.
    """
    return self._loss_name
89 changes: 89 additions & 0 deletions tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,92 @@ def test_decode_head():
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45)

# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])

inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss

# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss

# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)

# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss

# test multi-loss, loss_decode is list of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)

head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']
Loading