Skip to content

Commit

Permalink
Merge 0b3a22b into 0b11d58
Browse files Browse the repository at this point in the history
  • Loading branch information
MengzhangLI authored Sep 24, 2021
2 parents 0b11d58 + 0b3a22b commit 01bff41
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 33 deletions.
18 changes: 18 additions & 0 deletions docs/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,21 @@ model=dict(
```

`class_weight` will be passed into `CrossEntropyLoss` as `weight` argument. Please refer to [PyTorch Doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.

## Multiple Losses

For loss calculation, we support training with multiple losses at the same time. Here is an example config for training `unet` on the `DRIVE` dataset, whose loss function is a `1:3` weighted sum of `CrossEntropyLoss` and `DiceLoss`:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
```

In this way, `loss_weight` and `loss_name` will be, respectively, the weight of the corresponding loss and its name in the training log.

Note: If you want this loss item to be included in the backward graph, its name must start with the prefix `loss_`.
19 changes: 19 additions & 0 deletions docs_zh-CN/tutorials/training_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,22 @@ model=dict(
```

`class_weight` 将被作为 `weight` 参数,传递给 `CrossEntropyLoss`。详细信息请参照 [PyTorch 文档](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss)

## 同时使用多种损失函数 (Multiple Losses)

对于训练时损失函数的计算,我们目前支持多个损失函数同时使用。 以 `unet` 使用 `DRIVE` 数据集训练为例,
使用 `CrossEntropyLoss` 和 `DiceLoss` 的 `1:3` 的加权和作为损失函数。配置文件写为:

```python
_base_ = './fcn_unet_s5-d16_64x64_40k_drive.py'
model = dict(
decode_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
auxiliary_head=dict(loss_decode=[dict(type='CrossEntropyLoss', loss_name='loss_ce',loss_weight=1.0),
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)]),
)
```

通过这种方式,可以确定训练过程中损失函数的权重 `loss_weight` 和在训练日志里的名字 `loss_name`。

注意: `loss_name` 的名字必须带有 `loss_` 前缀,这样它才能被包括在反传的图里。
14 changes: 8 additions & 6 deletions mmseg/core/seg/sampler/ohem_pixel_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ def sample(self, seg_logit, seg_label):
threshold = max(min_threshold, self.thresh)
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
else:
losses = self.context.loss_decode(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
losses = 0.0
for loss_module in self.context.loss_decode:
losses += loss_module(
seg_logit,
seg_label,
weight=None,
ignore_index=self.context.ignore_index,
reduction_override='none')
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
_, sort_indices = losses[valid_mask].sort(descending=True)
valid_seg_weight[sort_indices[:batch_kept]] = 1.
Expand Down
43 changes: 35 additions & 8 deletions mmseg/models/decode_heads/decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
loss_decode (dict | Sequence[dict]): Config of decode loss.
`loss_name` is a property of the corresponding loss function and
is the name shown in the training log. If you want this loss
item to be included in the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
e.g. dict(type='CrossEntropyLoss'),
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='DiceLoss', loss_name='loss_dice')]
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
masked BCE loss, ignore_index should be set to None. Default: 255.
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Expand Down Expand Up @@ -73,9 +80,20 @@ def __init__(self,
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.loss_decode = build_loss(loss_decode)

self.ignore_index = ignore_index
self.align_corners = align_corners
self.loss_decode = nn.ModuleList()

if isinstance(loss_decode, dict):
self.loss_decode.append(build_loss(loss_decode))
elif isinstance(loss_decode, (list, tuple)):
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
else:
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
but got {type(loss_decode)}')

if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
Expand Down Expand Up @@ -224,10 +242,19 @@ def losses(self, seg_logit, seg_label):
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
loss['loss_seg'] = self.loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
for loss_decode in self.loss_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
else:
loss[loss_decode.loss_name] += loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)

loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss
5 changes: 3 additions & 2 deletions mmseg/models/decode_heads/point_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,9 @@ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
def losses(self, point_logits, point_label):
"""Compute segmentation loss."""
loss = dict()
loss['loss_point'] = self.loss_decode(
point_logits, point_label, ignore_index=self.ignore_index)
for loss_module in self.loss_decode:
loss['point' + loss_module.loss_name] = loss_module(
point_logits, point_label, ignore_index=self.ignore_index)
loss['acc_point'] = accuracy(point_logits, point_label)
return loss

Expand Down
21 changes: 20 additions & 1 deletion mmseg/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ class CrossEntropyLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
"""

def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
loss_weight=1.0,
loss_name='loss_ce'):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
Expand All @@ -172,6 +176,7 @@ def __init__(self,
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name

def forward(self,
cls_score,
Expand All @@ -197,3 +202,17 @@ def forward(self,
avg_factor=avg_factor,
**kwargs)
return loss_cls

@property
def loss_name(self):
    """Return the name of this loss item.

    Every loss module is expected to provide this property. The name
    identifies the loss item, and loss items are combined by a simple sum
    based on it. For the item to be included in the backward graph, the
    name must start with the prefix `loss_`.

    Returns:
        str: The name of this loss item.
    """
    # The name is stored at construction time from the `loss_name` argument.
    return self._loss_name
19 changes: 19 additions & 0 deletions mmseg/models/losses/dice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class DiceLoss(nn.Module):
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Default to 1.0.
ignore_index (int | None): The label index to be ignored. Default: 255.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_dice'.
"""

def __init__(self,
Expand All @@ -77,6 +80,7 @@ def __init__(self,
class_weight=None,
loss_weight=1.0,
ignore_index=255,
loss_name='loss_dice',
**kwards):
super(DiceLoss, self).__init__()
self.smooth = smooth
Expand All @@ -85,6 +89,7 @@ def __init__(self,
self.class_weight = get_class_weight(class_weight)
self.loss_weight = loss_weight
self.ignore_index = ignore_index
self._loss_name = loss_name

def forward(self,
pred,
Expand Down Expand Up @@ -118,3 +123,17 @@ def forward(self,
class_weight=class_weight,
ignore_index=self.ignore_index)
return loss

@property
def loss_name(self):
    """Return the name of this loss item.

    Every loss module is expected to provide this property. The name
    identifies the loss item, and loss items are combined by a simple sum
    based on it. For the item to be included in the backward graph, the
    name must start with the prefix `loss_`.

    Returns:
        str: The name of this loss item.
    """
    # The name is stored at construction time from the `loss_name` argument.
    return self._loss_name
21 changes: 20 additions & 1 deletion mmseg/models/losses/lovasz_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ class LovaszLoss(nn.Module):
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_lovasz'.
"""

def __init__(self,
Expand All @@ -252,7 +255,8 @@ def __init__(self,
per_image=False,
reduction='mean',
class_weight=None,
loss_weight=1.0):
loss_weight=1.0,
loss_name='loss_lovasz'):
super(LovaszLoss, self).__init__()
assert loss_type in ('binary', 'multi_class'), "loss_type should be \
'binary' or 'multi_class'."
Expand All @@ -271,6 +275,7 @@ def __init__(self,
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self._loss_name = loss_name

def forward(self,
cls_score,
Expand Down Expand Up @@ -302,3 +307,17 @@ def forward(self,
avg_factor=avg_factor,
**kwargs)
return loss_cls

@property
def loss_name(self):
    """Return the name of this loss item.

    Every loss module is expected to provide this property. The name
    identifies the loss item, and loss items are combined by a simple sum
    based on it. For the item to be included in the backward graph, the
    name must start with the prefix `loss_`.

    Returns:
        str: The name of this loss item.
    """
    # The name is stored at construction time from the `loss_name` argument.
    return self._loss_name
89 changes: 89 additions & 0 deletions tests/test_models/test_heads/test_decode_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,92 @@ def test_decode_head():
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45)

# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])

inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss

# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss

# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)

# test multi-loss, loss_decode is tuple of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss

# test multi-loss, loss_decode is tuple of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)

head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']
Loading

0 comments on commit 01bff41

Please sign in to comment.