
Dev bcewithlogits loss #4024

Merged: 10 commits merged into master from dev_bcewithlogits_loss on Dec 24, 2020

Conversation

@MARD1NO (Contributor) commented Dec 22, 2020

  1. Add the BCEWithLogitsLoss operator.
  2. Remove the sigmoid computation that the original BCELoss implementation introduced. There is no check on the value range yet, so inputs must be guaranteed to lie in [0, 1]; the operator will be rewritten in the C++ backend later.

BCEWithLogitsLoss is equivalent to sigmoid + BCELoss. For better numerical stability, however, PyTorch does not apply the sigmoid directly; it performs a few mathematical transformations instead. Its source code is below:

Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target,
                                        const Tensor& weight, const Tensor& pos_weight,
                                        int64_t reduction) {
    Tensor loss;
    // max_val = max(-input, 0); with this shift every exponent below is <= 0.
    auto max_val = (-input).clamp_min_(0);
    if (pos_weight.defined()) {
        // pos_weight needs to be broadcast, thus mul(target) is not in-place.
        // log_weight = (pos_weight - 1) * target + 1
        auto log_weight = (pos_weight - 1).mul(target).add_(1);
        // loss = (1 - t) * x + log_weight * (log(e^{-m} + e^{-x - m}) + m)
        loss = (1 - target).mul_(input).add_(
            log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
    } else {
        // loss = (1 - t) * x + m + log(e^{-m} + e^{-x - m})
        loss = (1 - target).mul_(input).add_(max_val).add_(
            (-max_val).exp_().add_((-input - max_val).exp_()).log_());
    }

    if (weight.defined()) {
        loss.mul_(weight);
    }

    return apply_loss_reduction(loss, reduction);
}

The implementation relies on a series of mathematical transformations as well as many in-place operations.

To make the overall computation easier to follow, a Python version of the computation will be provided in the unit tests.
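
For reference, a minimal NumPy sketch of the same stable computation (illustrative only, not the PR's actual test code; the function and argument names here are made up):

import numpy as np

def bce_with_logits_np(x, t, weight=None, pos_weight=None, reduction="mean"):
    # m = max(-x, 0): shifts every exponent below so it is <= 0
    max_val = np.clip(-x, 0.0, None)
    # log(1 + e^{-x}) computed via the log-sum-exp trick
    lse = max_val + np.log(np.exp(-max_val) + np.exp(-x - max_val))
    if pos_weight is not None:
        # same role as log_weight in the C++ code above
        loss = (1 - t) * x + ((pos_weight - 1) * t + 1) * lse
    else:
        loss = (1 - t) * x + lse
    if weight is not None:
        loss = loss * weight
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss

With x the logits and t the targets this matches sigmoid followed by BCE up to floating-point error, but it never evaluates an exponential with a positive exponent.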

@Ldpe2G (Contributor) commented Dec 24, 2020

I worked through the derivation and now see why PyTorch's formula takes this form:

[image: derivation of the numerically stable formula]

Below is the version with pos_weight:

[image: derivation of the pos_weight variant]

The docs say this uses the log-sum-exp trick; some background:
https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/
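
Written out (reconstructed from the C++ code above, with x the logit, t the target, \sigma the sigmoid, and m = \max(-x, 0)):

\ell(x, t) = -\,t \log \sigma(x) - (1 - t) \log\bigl(1 - \sigma(x)\bigr)
           = (1 - t)\,x + \log\bigl(1 + e^{-x}\bigr)

% log-sum-exp: factor e^{m} out of 1 + e^{-x} so both remaining exponents are <= 0
\log\bigl(1 + e^{-x}\bigr) = m + \log\bigl(e^{-m} + e^{-x - m}\bigr)

% with pos_weight p, only the target term is scaled, giving log_weight = (p - 1)t + 1
\ell_p(x, t) = (1 - t)\,x + \bigl((p - 1)\,t + 1\bigr)\Bigl(m + \log\bigl(e^{-m} + e^{-x - m}\bigr)\Bigr)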

@Ldpe2G Ldpe2G marked this pull request as ready for review December 24, 2020 07:17
@MARD1NO (Contributor, Author) commented Dec 24, 2020

> I worked through the derivation and now see why PyTorch's formula takes this form [derivation images above]. The docs say this uses the log-sum-exp trick: https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/

I had looked at a few softmax-based versions of this before; it seems clearer now. To avoid overflow when computing e^(-x) for negative inputs, the input is first negated and the max is taken, and the log-sum-exp trick then sidesteps the numerical overflow.
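
A quick numeric check of that point, with made-up values:

import numpy as np

x = np.array([-800.0])   # large-magnitude negative logit
t = np.array([1.0])

# Naive route: exp(800) overflows to inf, sigmoid underflows to 0, log(0) = -inf
naive = -(t * np.log(1.0 / (1.0 + np.exp(-x))))   # -> [inf], with overflow warnings

# Stable route: negating and subtracting the max keeps every exponent <= 0
m = np.clip(-x, 0.0, None)
stable = (1 - t) * x + m + np.log(np.exp(-m) + np.exp(-x - m))   # -> [800.]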

@oneflow-ci-bot oneflow-ci-bot merged commit 40e939e into master Dec 24, 2020
@oneflow-ci-bot oneflow-ci-bot deleted the dev_bcewithlogits_loss branch December 24, 2020 13:27
liujuncheng pushed a commit that referenced this pull request Jun 3, 2021
* remove sigmoid

* add Bce with logits loss

* add test cases

* refine

* change random distribution

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: 40e939e