
Commit

revise 5-5
lyhue1991 committed Jun 18, 2020
1 parent d3b624e commit ec43286
Showing 2 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion 5-5,损失函数losses.md
@@ -100,7 +100,7 @@ _________________________________________________________________

It has two adjustable parameters, alpha and gamma. The alpha parameter is mainly used to decay the weight of negative samples, while the gamma parameter is mainly used to decay the weight of easy-to-train samples.

-This makes the model focus more on the positive samples and the hard samples.
+This makes the model focus more on the positive samples and the hard samples. This is why this loss function is called Focal Loss.

For details, see "Understand Focal Loss and GHM in 5 Minutes: A Tool for Tackling Sample Imbalance".

31 changes: 18 additions & 13 deletions english/Chapter5-5.md
@@ -96,18 +96,23 @@ It is also possible to customize the loss function by inheriting from the base class

Here is an example of a customized implementation of Focal Loss, which is an improvement over the `binary_crossentropy` loss function.

-Focal Loss gives better results than binary cross entropy when the categories are unbalanced and the samples are difficult to train.
+Focal Loss gives better results than binary cross entropy when the categories are unbalanced and the training data contains many easy samples.

-You may refer to the following article for details of this topic: [How to evaluate Kaiming's "Focal Loss for Dense Object Detection"?](https://www.zhihu.com/question/63581984)
+It has two adjustable parameters, alpha and gamma. The aim of alpha is to decay the weight of negative samples, and the aim of gamma is to decay the weight of easy samples.
+
+The model will then focus its attention on the positive samples and the hard samples. This is why the loss is called Focal Loss.
+
+You may refer to the following article for details of this topic: [Understand Focal Loss and GHM in 5 minutes](https://www.zhihu.com/question/63581984)
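
In equation form, what the code below implements is the standard binary focal loss from Lin et al., "Focal Loss for Dense Object Detection"; this reference sketch is added here for context, with p denoting the predicted probability of the positive class:

```latex
% Binary focal loss for a label y in {0, 1}:
%   p_t      = p      if y = 1,  else 1 - p
%   \alpha_t = \alpha if y = 1,  else 1 - \alpha
\mathrm{FL}(p_t) = -\,\alpha_t \,(1 - p_t)^{\gamma}\,\log(p_t)
```

For easy samples p_t is close to 1, so the modulating factor (1 - p_t)^gamma shrinks toward 0 and their contribution decays; alpha_t re-weights positives against negatives, exactly as described above.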

```python
-def focal_loss(gamma=2., alpha=0.25):
+def focal_loss(gamma=2., alpha=0.75):
 
     def focal_loss_fixed(y_true, y_pred):
-        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
-        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
-        loss = -tf.reduce_sum(alpha * tf.pow(1. - pt_1, gamma) * tf.log(1e-07 + pt_1)) \
-               -tf.reduce_sum((1 - alpha) * tf.pow(pt_0, gamma) * tf.log(1. - pt_0 + 1e-07))
+        bce = tf.losses.binary_crossentropy(y_true, y_pred)
+        p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
+        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
+        modulating_factor = tf.pow(1.0 - p_t, gamma)
+        loss = tf.reduce_sum(alpha_factor * modulating_factor * bce, axis=-1)
         return loss
     return focal_loss_fixed
```
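
As a quick sanity check of the rewritten closure, it can be called directly on a toy batch; a minimal sketch, with hypothetical sample values and assuming TensorFlow 2.x:

```python
import tensorflow as tf

# Toy batch: one row of three binary labels and predicted probabilities.
loss_fn = focal_loss(gamma=2., alpha=0.75)
y_true = tf.constant([[1., 0., 1.]])
y_pred = tf.constant([[0.9, 0.1, 0.3]])
print(loss_fn(y_true, y_pred))  # per-sample focal loss, shape (1,)
```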

@@ -116,16 +121,16 @@ def focal_loss(gamma=2., alpha=0.25):
```python
 class FocalLoss(losses.Loss):
 
-    def __init__(self, gamma=2.0, alpha=0.25):
+    def __init__(self, gamma=2.0, alpha=0.75, name="focal_loss"):
+        super(FocalLoss, self).__init__(name=name)  # assumed fix: register `name` with the Keras Loss base class
         self.gamma = gamma
         self.alpha = alpha
 
     def call(self, y_true, y_pred):
 
-        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
-        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
-        loss = -tf.reduce_sum(self.alpha * tf.pow(1. - pt_1, self.gamma) * tf.log(1e-07 + pt_1)) \
-               -tf.reduce_sum((1 - self.alpha) * tf.pow(pt_0, self.gamma) * tf.log(1. - pt_0 + 1e-07))
+        bce = tf.losses.binary_crossentropy(y_true, y_pred)
+        p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
+        alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)
+        modulating_factor = tf.pow(1.0 - p_t, self.gamma)
+        loss = tf.reduce_sum(alpha_factor * modulating_factor * bce, axis=-1)
         return loss
```
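
A minimal sketch of plugging the class version into training; the one-layer model below is hypothetical and only illustrates where the loss goes:

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Hypothetical toy binary classifier with a sigmoid output.
model = models.Sequential([layers.Dense(1, activation="sigmoid", input_shape=(8,))])
model.compile(optimizer="adam",
              loss=FocalLoss(gamma=2.0, alpha=0.75),
              metrics=["accuracy"])
```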

