# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Focal loss."""
import tensorflow as tf
import tensorflow.keras.backend as K
from typeguard import typechecked
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
@tf.keras.utils.register_keras_serializable(package="Addons")
class SigmoidFocalCrossEntropy(LossFunctionWrapper):
"""Implements the focal loss function.
Focal loss was first introduced in the RetinaNet paper
(https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
classification when you have highly imbalanced classes. It down-weights
well-classified examples and focuses on hard examples. The loss value is
much high for a sample which is misclassified by the classifier as compared
to the loss value corresponding to a well-classified example. One of the
best use-cases of focal loss is its usage in object detection where the
imbalance between the background class and other classes is extremely high.
Usage:
>>> fl = tfa.losses.SigmoidFocalCrossEntropy()
>>> loss = fl(
... y_true = [[1.0], [1.0], [0.0]],y_pred = [[0.97], [0.91], [0.03]])
>>> loss
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([6.8532745e-06, 1.9097870e-04, 2.0559824e-05],
dtype=float32)>
Usage with `tf.keras` API:
>>> model = tf.keras.Model()
>>> model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())
Args:
alpha: balancing factor, default value is 0.25.
gamma: modulating factor, default value is 2.0.
Returns:
Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
shape as `y_true`; otherwise, it is scalar.
Raises:
ValueError: If the shape of `sample_weight` is invalid or value of
`gamma` is less than zero.
"""

    @typechecked
    def __init__(
        self,
        from_logits: bool = False,
        alpha: FloatTensorLike = 0.25,
        gamma: FloatTensorLike = 2.0,
        reduction: str = tf.keras.losses.Reduction.NONE,
        name: str = "sigmoid_focal_crossentropy",
    ):
        super().__init__(
            sigmoid_focal_crossentropy,
            name=name,
            reduction=reduction,
            from_logits=from_logits,
            alpha=alpha,
            gamma=gamma,
        )
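

# A minimal usage sketch (illustrative only; `fl`, `y_true`, and `y_pred` are
# hypothetical names, not part of this module): the class above is a thin
# `LossFunctionWrapper` around the functional form defined below, so with the
# default `Reduction.NONE` and no sample weights these calls return the same
# per-sample tensor:
#
#     fl = SigmoidFocalCrossEntropy(alpha=0.25, gamma=2.0)
#     fl(y_true, y_pred)
#     sigmoid_focal_crossentropy(y_true, y_pred, alpha=0.25, gamma=2.0)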


@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def sigmoid_focal_crossentropy(
    y_true: TensorLike,
    y_pred: TensorLike,
    alpha: FloatTensorLike = 0.25,
    gamma: FloatTensorLike = 2.0,
    from_logits: bool = False,
) -> tf.Tensor:
"""Implements the focal loss function.
Focal loss was first introduced in the RetinaNet paper
(https://arxiv.org/pdf/1708.02002.pdf). Focal loss is extremely useful for
classification when you have highly imbalanced classes. It down-weights
well-classified examples and focuses on hard examples. The loss value is
much high for a sample which is misclassified by the classifier as compared
to the loss value corresponding to a well-classified example. One of the
best use-cases of focal loss is its usage in object detection where the
imbalance between the background class and other classes is extremely high.
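
    The per-element loss computed below is the alpha-balanced focal loss from
    the paper, summed over the last axis, where `p_t` is the predicted
    probability of the true class and `alpha_t` is `alpha` for positive and
    `1 - alpha` for negative targets:

        FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)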

    Args:
        y_true: true targets tensor.
        y_pred: predictions tensor.
        alpha: balancing factor.
        gamma: modulating factor.
        from_logits: whether `y_pred` contains logits rather than
            probabilities.

    Returns:
        Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the
        same shape as `y_true`; otherwise, it is scalar.
    """
    if gamma and gamma < 0:
        raise ValueError("Value of gamma should be greater than or equal to zero.")

    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true, dtype=y_pred.dtype)

    # Get the cross_entropy for each entry
    ce = K.binary_crossentropy(y_true, y_pred, from_logits=from_logits)

    # If logits are provided then convert the predictions into probabilities
    if from_logits:
        pred_prob = tf.sigmoid(y_pred)
    else:
        pred_prob = y_pred

    # p_t is the predicted probability of the true class
    p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))

    alpha_factor = 1.0
    modulating_factor = 1.0

    if alpha:
        alpha = tf.convert_to_tensor(alpha, dtype=K.floatx())
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)

    if gamma:
        gamma = tf.convert_to_tensor(gamma, dtype=K.floatx())
        modulating_factor = tf.pow((1.0 - p_t), gamma)

    # Compute the final loss and return
    return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis=-1)
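

if __name__ == "__main__":
    # Minimal self-check sketch (an illustrative addition, not part of the
    # upstream module; the demo values are made up): re-evaluates the
    # docstring example and compares the result against a direct evaluation
    # of FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t) with the default
    # alpha = 0.25 and gamma = 2.0.
    import math

    y_true = [[1.0], [1.0], [0.0]]
    y_pred = [[0.97], [0.91], [0.03]]
    loss = sigmoid_focal_crossentropy(y_true=y_true, y_pred=y_pred)

    for (y,), (p,), fl in zip(y_true, y_pred, loss.numpy()):
        p_t = y * p + (1 - y) * (1 - p)      # probability of the true class
        alpha_t = y * 0.25 + (1 - y) * 0.75  # balancing factor
        expected = -alpha_t * (1 - p_t) ** 2 * math.log(p_t)
        assert abs(fl - expected) < 1e-6, (float(fl), expected)

    # The logits path should agree with the probability path, since the
    # function converts logits back to probabilities with a sigmoid.
    probs = tf.constant(y_pred)
    logits = tf.math.log(probs / (1.0 - probs))  # inverse of the sigmoid
    loss_logits = sigmoid_focal_crossentropy(
        y_true=y_true, y_pred=logits, from_logits=True
    )
    assert all(
        abs(a - b) < 1e-6 for a, b in zip(loss.numpy(), loss_logits.numpy())
    )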