-
Notifications
You must be signed in to change notification settings - Fork 50
/
losses.py
97 lines (83 loc) · 3.3 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import Optional, Sequence, Union
import torch
class CombinedLoss(torch.nn.Module):
    """Defines a loss function as a weighted sum of combinable loss criteria.

    Every criterion is called with the same positional arguments that are
    passed to :meth:`forward`; each result is multiplied by its weight and the
    weighted terms are summed.

    Args:
        criteria: List of loss criterion modules that should be combined.
        weight: Weight assigned to the individual loss criteria (in the same
            order as ``criteria``). Defaults to 1 for every criterion.
        device: Device on which the ``weight`` buffer and the zero that starts
            the sum are created. The sum is accumulated out of place, so with
            the default ``device=None`` the criterion outputs may live on any
            device.

    Raises:
        ValueError: If ``weight`` does not have exactly one entry per
            criterion.
    """

    def __init__(
        self,
        criteria: Sequence[torch.nn.Module],
        weight: Optional[Sequence[float]] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.criteria = torch.nn.ModuleList(criteria)
        self.device = device
        if weight is None:
            weight = torch.ones(len(criteria))
        else:
            weight = torch.as_tensor(weight, dtype=torch.float32)
            # Explicit check rather than ``assert``: asserts are stripped under
            # ``python -O``, after which ``zip`` in forward() would silently
            # ignore surplus criteria or weights.
            if weight.shape != (len(criteria),):
                raise ValueError(
                    f"weight must have shape ({len(criteria)},) to match the "
                    f"number of criteria, got {tuple(weight.shape)}."
                )
        self.register_buffer("weight", weight.to(self.device))

    def forward(self, *args):
        """Return the weighted sum of all criteria evaluated on ``args``.

        With no criteria the result is a 0-dim zero tensor.
        """
        loss = torch.tensor(0.0, device=self.device)
        for crit, weight in zip(self.criteria, self.weight):
            # Out-of-place add on purpose: ``loss += x`` writes into ``loss`` and
            # therefore requires ``x`` to already be on ``loss``'s device, while
            # ``loss + x`` lets a 0-dim CPU ``loss`` (``device=None``) act as a
            # scalar next to criterion outputs on any device.
            loss = loss + weight * crit(*args)
        return loss
def _channelwise_sum(x: torch.Tensor) -> torch.Tensor:
"""Sum-reduce all dimensions of a tensor except dimension 1 (C)"""
reduce_dims = tuple([0] + list(range(x.dim()))[2:]) # = (0, 2, 3, ...)
return x.sum(dim=reduce_dims)
def dice_loss(
probs: torch.Tensor,
target: torch.Tensor,
weight: float = 1.0,
eps: float = 0.0001,
smooth: float = 0.0,
):
tsh, psh = target.shape, probs.shape
if tsh == psh: # Already one-hot
onehot_target = target.to(probs.dtype)
elif (
tsh[0] == psh[0] and tsh[1:] == psh[2:]
): # Assume dense target storage, convert to one-hot
onehot_target = torch.zeros_like(probs)
onehot_target.scatter_(1, target.unsqueeze(1), 1)
else:
raise ValueError(
f"Target shape {target.shape} is not compatible with output shape {probs.shape}."
)
intersection = probs * onehot_target # (N, C, ...)
numerator = 2 * _channelwise_sum(intersection) + smooth # (C,)
denominator = probs + onehot_target # (N, C, ...)
denominator = _channelwise_sum(denominator) + smooth + eps # (C,)
loss_per_channel = 1 - (numerator / denominator) # (C,)
weighted_loss_per_channel = weight * loss_per_channel # (C,)
return weighted_loss_per_channel.mean() # ()
class DiceLoss(torch.nn.Module):
    """Module wrapper around ``dice_loss``.

    Args:
        apply_softmax: If True, a softmax over dim 1 is applied to the model
            output before computing the loss; otherwise the output is passed
            to ``dice_loss`` unchanged.
        weight: Optional scalar or per-channel weight forwarded to
            ``dice_loss``. It is registered as a buffer, so it moves with the
            module on ``.to()``. Non-tensor values are converted with
            ``torch.as_tensor``. Defaults to ``torch.tensor(1.0)``.
        smooth: Smoothing term forwarded to ``dice_loss``.
    """

    def __init__(
        self,
        apply_softmax: bool = True,
        weight: Optional[Union[torch.Tensor, Sequence[float], float]] = None,
        smooth: float = 0.0,
    ):
        super().__init__()
        if apply_softmax:
            self.softmax = torch.nn.Softmax(dim=1)
        else:
            # nn.Identity rather than a lambda: lambdas cannot be pickled, which
            # breaks torch.save(model) / multiprocessing for models holding
            # this loss.
            self.softmax = torch.nn.Identity()
        self.dice = dice_loss
        if weight is None:
            weight = torch.tensor(1.0)
        # as_tensor returns an existing tensor unchanged and converts
        # floats / sequences, which register_buffer would otherwise reject.
        self.register_buffer("weight", torch.as_tensor(weight))
        self.smooth = smooth

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the Dice loss between ``output`` and ``target``.

        ``output`` goes through the optional softmax first; ``target`` is
        passed to ``dice_loss`` as-is.
        """
        probs = self.softmax(output)
        return self.dice(
            probs=probs, target=target, weight=self.weight, smooth=self.smooth
        )