Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support multiple losses during training #818

Merged
merged 15 commits into from
Sep 24, 2021
Prev Previous commit
Next Next commit
loss_name must have 'loss_' prefix
  • Loading branch information
MengzhangLI committed Sep 23, 2021
commit 96213d392c62dde512201ad5f5a4f0c19b1ddb52
10 changes: 6 additions & 4 deletions docs/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ For loss calculation, we support multiple losses training concurrently. Here is
```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='CE', loss_weight=1.0),
dict(type='DiceLoss', loss_name='Dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='CE',loss_weight=1.0),
dict(type='DiceLoss', loss_name='Dice', loss_weight=3.0)]),
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
MengzhangLI marked this conversation as resolved.
Show resolved Hide resolved
```

In this way, `loss_weight` and `loss_name` will be weight and name in training log of corresponding loss, respectively.

Note: If you want this loss item to be included in the backward graph, `loss_` must be the prefix of the name.
10 changes: 6 additions & 4 deletions docs_zh-CN/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ model=dict(
```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='CE', loss_weight=1.0),
dict(type='DiceLoss', loss_name='Dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='CE',loss_weight=1.0),
dict(type='DiceLoss', loss_name='Dice', loss_weight=3.0)]),
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
```

通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
10 changes: 6 additions & 4 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
Default: None.
loss_decode (dict | Sequence[dict]): Config of decode loss.
The `loss_name` is property of corresponding loss function which
could be shown in training log.
could be shown in training log. If you want this loss
item to be included in the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='CE'),
dict(type='DiceLoss', loss_name='Dice')]
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
Junjun2016 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -240,7 +242,7 @@ def losses(self, seg_logit, seg_label):
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
for i, loss_decode in enumerate(self.loss_decode):
for loss_decode in self.loss_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
Expand Down
32 changes: 22 additions & 10 deletions tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def test_decode_head():
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='ce_1'),
dict(type='CrossEntropyLoss', loss_name='ce_2')
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'ce_1' in loss
assert 'ce_2' in loss
assert 'loss_1' in loss
assert 'loss_2' in loss

# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
Expand All @@ -125,16 +125,16 @@ def test_decode_head():
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='ce_1'),
dict(type='CrossEntropyLoss', loss_name='ce_2'),
dict(type='CrossEntropyLoss', loss_name='ce_3')))
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'ce_1' in loss
assert 'ce_2' in loss
assert 'ce_3' in loss
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss

# test multi-loss, loss_decode is list of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
Expand All @@ -149,5 +149,17 @@ def test_decode_head():
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)

head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']