diff --git a/docs/tutorials/training_tricks.md b/docs/tutorials/training_tricks.md index 98a201fa649..1c8fe06b943 100644 --- a/docs/tutorials/training_tricks.md +++ b/docs/tutorials/training_tricks.md @@ -50,3 +50,21 @@ model=dict( ``` `class_weight` will be passed into `CrossEntropyLoss` as `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details. + +## Multiple Losses + +For loss calculation, we support training with multiple losses simultaneously. Here is an example config for training `unet` on the `DRIVE` dataset, whose loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`: + +```python +_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py' +model = dict( + decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), + dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]), + auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0), + dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]), + ) +``` + +In this way, `loss_weight` and `loss_name` will be the weight and the name in the training log of the corresponding loss, respectively. + +Note: If you want this loss item to be included in the backward graph, `loss_` must be the prefix of the name. 
diff --git a/docs_zh-CN/tutorials/training_tricks.md b/docs_zh-CN/tutorials/training_tricks.md index 9248e5a14be..be9112cabd6 100644 --- a/docs_zh-CN/tutorials/training_tricks.md +++ b/docs_zh-CN/tutorials/training_tricks.md @@ -49,3 +49,22 @@ model=dict( ``` `class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) 。 + +## 同时使用多种损失函数 (Multiple Losses) + +对于训练时损失函数的计算,我们目前支持多个损失函数同时使用。 以 `unet` 使用 `DRIVE` 数据集训练为例, +使用 `CrossEntropyLoss` 和 `DiceLoss` 的 `1:3` 的加权和作为损失函数。配置文件写为: + +```python +_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py' +model = dict( + decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0), + dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]), + auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0), + dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]), + ) +``` + +通过这种方式,确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。 + +注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。 diff --git a/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/mmseg/core/seg/sampler/ohem_pixel_sampler.py index bcd481a9653..72ba941f03d 100644 --- a/mmseg/core/seg/sampler/ohem_pixel_sampler.py +++ b/mmseg/core/seg/sampler/ohem_pixel_sampler.py @@ -62,12 +62,14 @@ def sample(self, seg_logit, seg_label): threshold = max(min_threshold, self.thresh) valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 
else: - losses = self.context.loss_decode( - seg_logit, - seg_label, - weight=None, - ignore_index=self.context.ignore_index, - reduction_override='none') + losses = 0.0 + for loss_module in self.context.loss_decode: + losses += loss_module( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa _, sort_indices = losses[valid_mask].sort(descending=True) valid_seg_weight[sort_indices[:batch_kept]] = 1. diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index b38701a92ec..c36555eaf28 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): a list and passed into decode head. None: Only one select feature map is allowed. Default: None. - loss_decode (dict): Config of decode loss. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] Default: dict(type='CrossEntropyLoss'). ignore_index (int | None): The label index to be ignored. When using - masked BCE loss, ignore_index should be set to None. Default: 255 + masked BCE loss, ignore_index should be set to None. Default: 255. sampler (dict|None): The config of segmentation map sampler. Default: None. align_corners (bool): align_corners argument of F.interpolate. 
@@ -73,9 +80,20 @@ def __init__(self, self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.in_index = in_index - self.loss_decode = build_loss(loss_decode) + self.ignore_index = ignore_index self.align_corners = align_corners + self.loss_decode = nn.ModuleList() + + if isinstance(loss_decode, dict): + self.loss_decode.append(build_loss(loss_decode)) + elif isinstance(loss_decode, (list, tuple)): + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + if sampler is not None: self.sampler = build_pixel_sampler(sampler, context=self) else: @@ -224,10 +242,19 @@ def losses(self, seg_logit, seg_label): else: seg_weight = None seg_label = seg_label.squeeze(1) - loss['loss_seg'] = self.loss_decode( - seg_logit, - seg_label, - weight=seg_weight, - ignore_index=self.ignore_index) + for loss_decode in self.loss_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + loss['acc_seg'] = accuracy(seg_logit, seg_label) return loss diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py index 4470571144d..56dfd4ed8bb 100644 --- a/mmseg/models/decode_heads/point_head.py +++ b/mmseg/models/decode_heads/point_head.py @@ -249,8 +249,9 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg): def losses(self, point_logits, point_label): """Compute segmentation loss.""" loss = dict() - loss['loss_point'] = self.loss_decode( - point_logits, point_label, ignore_index=self.ignore_index) + for loss_module in self.loss_decode: + loss['point' + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index) loss['acc_point'] = 
accuracy(point_logits, point_label) return loss diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index 9a7ccea937f..ee489a888fd 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -150,6 +150,9 @@ class CrossEntropyLoss(nn.Module): class_weight (list[float] | str, optional): Weight of each class. If in str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. """ def __init__(self, @@ -157,7 +160,8 @@ def __init__(self, use_mask=False, reduction='mean', class_weight=None, - loss_weight=1.0): + loss_weight=1.0, + loss_name='loss_ce'): super(CrossEntropyLoss, self).__init__() assert (use_sigmoid is False) or (use_mask is False) self.use_sigmoid = use_sigmoid @@ -172,6 +176,7 @@ def __init__(self, self.cls_criterion = mask_cross_entropy else: self.cls_criterion = cross_entropy + self._loss_name = loss_name def forward(self, cls_score, @@ -197,3 +202,17 @@ def forward(self, avg_factor=avg_factor, **kwargs) return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py index 0b07e97648f..774bd1aea20 100644 --- a/mmseg/models/losses/dice_loss.py +++ b/mmseg/models/losses/dice_loss.py @@ -68,6 +68,9 @@ class DiceLoss(nn.Module): str format, read them from a file. 
Defaults to None. loss_weight (float, optional): Weight of the loss. Default to 1.0. ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_dice'. """ def __init__(self, @@ -77,6 +80,7 @@ def __init__(self, class_weight=None, loss_weight=1.0, ignore_index=255, + loss_name='loss_dice', **kwards): super(DiceLoss, self).__init__() self.smooth = smooth @@ -85,6 +89,7 @@ def __init__(self, self.class_weight = get_class_weight(class_weight) self.loss_weight = loss_weight self.ignore_index = ignore_index + self._loss_name = loss_name def forward(self, pred, @@ -118,3 +123,17 @@ def forward(self, class_weight=class_weight, ignore_index=self.ignore_index) return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/mmseg/models/losses/lovasz_loss.py b/mmseg/models/losses/lovasz_loss.py index 275c4c54326..2bb0fad3931 100644 --- a/mmseg/models/losses/lovasz_loss.py +++ b/mmseg/models/losses/lovasz_loss.py @@ -244,6 +244,9 @@ class LovaszLoss(nn.Module): class_weight (list[float] | str, optional): Weight of each class. If in str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. 
""" def __init__(self, @@ -252,7 +255,8 @@ def __init__(self, per_image=False, reduction='mean', class_weight=None, - loss_weight=1.0): + loss_weight=1.0, + loss_name='loss_lovasz'): super(LovaszLoss, self).__init__() assert loss_type in ('binary', 'multi_class'), "loss_type should be \ 'binary' or 'multi_class'." @@ -271,6 +275,7 @@ def __init__(self, self.reduction = reduction self.loss_weight = loss_weight self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name def forward(self, cls_score, @@ -302,3 +307,17 @@ def forward(self, avg_factor=avg_factor, **kwargs) return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/tests/test_models/test_heads/test_decode_head.py b/tests/test_models/test_heads/test_decode_head.py index 421043d398e..cb9ab97181b 100644 --- a/tests/test_models/test_heads/test_decode_head.py +++ b/tests/test_models/test_heads/test_decode_head.py @@ -74,3 +74,92 @@ def test_decode_head(): assert head.input_transform == 'resize_concat' transformed_inputs = head._transform_inputs(inputs) assert transformed_inputs.shape == (1, 48, 45, 45) + + # test multi-loss, loss_decode is dict + with pytest.raises(TypeError): + # loss_decode must be a dict or sequence of dict. 
+ BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + + inputs = torch.randn(2, 19, 8, 8).float() + target = torch.ones(2, 1, 64, 64).long() + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + head, target = to_cuda(head, target) + loss = head.losses(seg_logit=inputs, seg_label=target) + assert 'loss_ce' in loss + + # test multi-loss, loss_decode is list of dict + inputs = torch.randn(2, 19, 8, 8).float() + target = torch.ones(2, 1, 64, 64).long() + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=[ + dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2') + ]) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + head, target = to_cuda(head, target) + loss = head.losses(seg_logit=inputs, seg_label=target) + assert 'loss_1' in loss + assert 'loss_2' in loss + + # 'loss_decode' must be a dict or sequence of dict + with pytest.raises(TypeError): + BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss']) + with pytest.raises(TypeError): + BaseDecodeHead(3, 16, num_classes=19, loss_decode=0) + + # test multi-loss, loss_decode is list of dict + inputs = torch.randn(2, 19, 8, 8).float() + target = torch.ones(2, 1, 64, 64).long() + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'), + dict(type='CrossEntropyLoss', loss_name='loss_2'), + dict(type='CrossEntropyLoss', loss_name='loss_3'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + head, target = to_cuda(head, target) + loss = head.losses(seg_logit=inputs, seg_label=target) + assert 'loss_1' in loss + assert 'loss_2' in loss + assert 'loss_3' in loss + + # test multi-loss, loss_decode is list of dict, names of them are identical + inputs = torch.randn(2, 
19, 8, 8).float() + target = torch.ones(2, 1, 64, 64).long() + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + head, target = to_cuda(head, target) + loss_3 = head.losses(seg_logit=inputs, seg_label=target) + + head = BaseDecodeHead( + 3, + 16, + num_classes=19, + loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'))) + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + head, target = to_cuda(head, target) + loss = head.losses(seg_logit=inputs, seg_label=target) + assert 'loss_ce' in loss + assert 'loss_ce' in loss_3 + assert loss_3['loss_ce'] == 3 * loss['loss_ce'] diff --git a/tests/test_models/test_losses/test_ce_loss.py b/tests/test_models/test_losses/test_ce_loss.py index 73217ec8c00..03bc3beef9f 100644 --- a/tests/test_models/test_losses/test_ce_loss.py +++ b/tests/test_models/test_losses/test_ce_loss.py @@ -20,7 +20,8 @@ def test_ce_loss(): type='CrossEntropyLoss', use_sigmoid=False, class_weight=[0.8, 0.2], - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_ce') loss_cls = build_loss(loss_cls_cfg) fake_pred = torch.Tensor([[100, -100]]) fake_label = torch.Tensor([1]).long() @@ -38,7 +39,8 @@ def test_ce_loss(): type='CrossEntropyLoss', use_sigmoid=False, class_weight=f'{tmp_file.name}.pkl', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_ce') loss_cls = build_loss(loss_cls_cfg) assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) @@ -47,7 +49,8 @@ def test_ce_loss(): type='CrossEntropyLoss', use_sigmoid=False, class_weight=f'{tmp_file.name}.npy', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_ce') loss_cls = build_loss(loss_cls_cfg) assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.)) tmp_file.close() @@ -74,4 
+77,12 @@ def test_ce_loss(): torch.tensor(0.9354), atol=1e-4) + # test cross entropy loss has name `loss_ce` + loss_cls_cfg = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + loss_name='loss_ce') + loss_cls = build_loss(loss_cls_cfg) + assert loss_cls.loss_name == 'loss_ce' # TODO test use_mask diff --git a/tests/test_models/test_losses/test_dice_loss.py b/tests/test_models/test_losses/test_dice_loss.py index 05d1b1e0538..d6f10439d86 100644 --- a/tests/test_models/test_losses/test_dice_loss.py +++ b/tests/test_models/test_losses/test_dice_loss.py @@ -11,7 +11,8 @@ def test_dice_lose(): reduction='none', class_weight=[1.0, 2.0, 3.0], loss_weight=1.0, - ignore_index=1) + ignore_index=1, + loss_name='loss_dice') dice_loss = build_loss(loss_cfg) logits = torch.rand(8, 3, 4, 4) labels = (torch.rand(8, 4, 4) * 3).long() @@ -30,7 +31,8 @@ def test_dice_lose(): reduction='none', class_weight=f'{tmp_file.name}.pkl', loss_weight=1.0, - ignore_index=1) + ignore_index=1, + loss_name='loss_dice') dice_loss = build_loss(loss_cfg) dice_loss(logits, labels, ignore_index=None) @@ -40,7 +42,8 @@ def test_dice_lose(): reduction='none', class_weight=f'{tmp_file.name}.pkl', loss_weight=1.0, - ignore_index=1) + ignore_index=1, + loss_name='loss_dice') dice_loss = build_loss(loss_cfg) dice_loss(logits, labels, ignore_index=None) tmp_file.close() @@ -54,8 +57,21 @@ def test_dice_lose(): exponent=3, reduction='sum', loss_weight=1.0, - ignore_index=0) + ignore_index=0, + loss_name='loss_dice') dice_loss = build_loss(loss_cfg) logits = torch.rand(8, 2, 4, 4) labels = (torch.rand(8, 4, 4) * 2).long() dice_loss(logits, labels) + + # test dice loss has name `loss_dice` + loss_cfg = dict( + type='DiceLoss', + smooth=2, + exponent=3, + reduction='sum', + loss_weight=1.0, + ignore_index=0, + loss_name='loss_dice') + dice_loss = build_loss(loss_cfg) + assert dice_loss.loss_name == 'loss_dice' diff --git a/tests/test_models/test_losses/test_lovasz_loss.py 
b/tests/test_models/test_losses/test_lovasz_loss.py index e2dee81de8c..74ddb48d8ea 100644 --- a/tests/test_models/test_losses/test_lovasz_loss.py +++ b/tests/test_models/test_losses/test_lovasz_loss.py @@ -12,16 +12,24 @@ def test_lovasz_loss(): type='LovaszLoss', loss_type='Binary', reduction='none', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') build_loss(loss_cfg) # reduction should be 'none' when per_image is False. with pytest.raises(AssertionError): - loss_cfg = dict(type='LovaszLoss', loss_type='multi_class') + loss_cfg = dict( + type='LovaszLoss', + loss_type='multi_class', + loss_name='loss_lovasz') build_loss(loss_cfg) # test lovasz loss with loss_type = 'multi_class' and per_image = False - loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0) + loss_cfg = dict( + type='LovaszLoss', + reduction='none', + loss_weight=1.0, + loss_name='loss_lovasz') lovasz_loss = build_loss(loss_cfg) logits = torch.rand(1, 3, 4, 4) labels = (torch.rand(1, 4, 4) * 2).long() @@ -33,7 +41,8 @@ def test_lovasz_loss(): per_image=True, reduction='mean', class_weight=[1.0, 2.0, 3.0], - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') lovasz_loss = build_loss(loss_cfg) logits = torch.rand(1, 3, 4, 4) labels = (torch.rand(1, 4, 4) * 2).long() @@ -52,7 +61,8 @@ def test_lovasz_loss(): per_image=True, reduction='mean', class_weight=f'{tmp_file.name}.pkl', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') lovasz_loss = build_loss(loss_cfg) lovasz_loss(logits, labels, ignore_index=None) @@ -62,7 +72,8 @@ def test_lovasz_loss(): per_image=True, reduction='mean', class_weight=f'{tmp_file.name}.npy', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') lovasz_loss = build_loss(loss_cfg) lovasz_loss(logits, labels, ignore_index=None) tmp_file.close() @@ -74,7 +85,8 @@ def test_lovasz_loss(): type='LovaszLoss', loss_type='binary', reduction='none', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') 
lovasz_loss = build_loss(loss_cfg) logits = torch.rand(2, 4, 4) labels = (torch.rand(2, 4, 4)).long() @@ -86,8 +98,20 @@ def test_lovasz_loss(): loss_type='binary', per_image=True, reduction='mean', - loss_weight=1.0) + loss_weight=1.0, + loss_name='loss_lovasz') lovasz_loss = build_loss(loss_cfg) logits = torch.rand(2, 4, 4) labels = (torch.rand(2, 4, 4)).long() lovasz_loss(logits, labels, ignore_index=None) + + # test lovasz loss has name `loss_lovasz` + loss_cfg = dict( + type='LovaszLoss', + loss_type='binary', + per_image=True, + reduction='mean', + loss_weight=1.0, + loss_name='loss_lovasz') + lovasz_loss = build_loss(loss_cfg) + assert lovasz_loss.loss_name == 'loss_lovasz'