Skip to content

Conversation

@zhengshengning
Copy link
Contributor

@zhengshengning zhengshengning commented Oct 16, 2025

PR Category

Operator Mechanism

PR Types

Performance

Description

torch实现 : aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu

void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] {
        using opmath_t = at::opmath_type<scalar_t>;

        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t {
          const opmath_t in = in_;
          const auto min = std::min(opmath_t(0), in);
          const auto z = std::exp(-std::abs(in));
          return min - std::log1p(z);
        });
      });
}

主要修改:

  1. 前向计算 (CudaLogSigmoidFunctor)
    旧实现: - (max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
    新实现: log_sigmoid(x) = min(0,x) - log1p(exp(-|x|))
  2. 反向计算 (CudaLogSigmoidGradFunctor)
    旧实现: dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x,0)))
    新实现:grad = dout * (max_deriv - sign * (z / (1 + z))) ; z = exp(-abs(x)); max_deriv = (x < 0) ? 1 : 0, sign = (x < 0) ? 1 : -1 ;

测试结果:
paddle.nn.functional.log_sigmoid 对齐case(共16 个)

pcard-93269

@paddle-bot
Copy link

paddle-bot bot commented Oct 16, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zhengshengning zhengshengning changed the title accuracy_stable_log_sigmoid [Precision Depth Alignment] paddle.log_sigmoid Oct 17, 2025
Copy link
Contributor

@Eddie-Wang1120 Eddie-Wang1120 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@A-nnonymous A-nnonymous left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhengshengning zhengshengning merged commit 37f7dbe into PaddlePaddle:develop Oct 17, 2025
57 of 58 checks passed
zhengshengning added a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
* accuracy_stable_log_sigmoid

* fix test_activation_stride_op.py
zhengshengning added a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
* accuracy_stable_log_sigmoid

* fix test_activation_stride_op.py
zhengshengning added a commit that referenced this pull request Oct 27, 2025
* CallScalarFunction uses the dtype of 'self' as the type of 'other' when opotype is 'div'(#75237)

* LinspaceKernel uses the dtype of 'self' as the type of 'step' when tensor is floating (#75238)

* align LinspaceKernel

* update meta

* update gpu kernel

* fix LinspaceKernelInner

* improve kernel

* fix CudaSigmoidGradFunctor and CudaSiluGradFunctor (#75341)

* Softplus accuracy and torch alignment 1 (#75363)

* [Precision Depth Alignment] paddle.tan reverse calculation: dx = dout *(1 + tan(x)^2) (#75335)

* Tan reverse calculation: dx = dout *(1 + tan(x)^2)

* [Precision Depth Alignment] Add support for CUDNN to paddle.nn.functional.grid_sample to align with torch accuracy.  (#75355)

* accuracy_stable_grid_sample

* fix

* correlation supports big tensor (#75383)

* fix

* fix test

* fix

* paddle.tanh Grad and torch alignment (float16) (#75454)

* [Precision Depth Alignment] paddle.sin and paddle.cos aligns with torch precision. (#75503)

* accuracy_stable_sin

* accuracy_stable_cos

* [深度对齐]Divide (#75379)

* fix

* fix

* fix

* fix

* fix

* [Precision Depth Alignment] fix precision for float16 of paddle.tan backward (#75525)

* fix precision for float16 of paddle.tan backward

* fix else branch of CudaTanGradFunctor

* [Precision Depth Alignment] fix precision for  paddle.expm1 (#75549)

* accuracy_stable_expm1

* fix

* Bigtensor排查修复[Paddle/paddle/phi/kernels/funcs] (#75523)

* fix

* fix

* [Precision Depth Alignment]  fix beta and threshold of paddle.nn.functional.softplus  to double (#75426)

* fix beta and threshold of Softplus to double

* fix test_softplus_activation_fuse_pass v1

* fix test_activation_zero

* fix flaot of SoftplusDoubleGradKernel to double

* add op_patches for softplus

* add yaml for ops/yaml/legacy

* fix infershape/operator for FLOAT64

* fix

* add SoftPlusOpTranscriber

* fix

* fix

* fix1

* fix2

* fix coverage

* fix coverage2

* fix (#75605)

* [深度对齐] dot (#75717)

* fix

* fix

* fix dcu

* [Precision Depth Alignment]  paddle.log aligns with torch precision (#75799)

* accuracy_stable_log

* accuracy_stable_log

* fix

* fix

* fix

* fix

* fix5

* [Precision Depth Alignment] fix eps of paddle.logit from float to double (#75816)

* accuracy_stable_logit

* add LogitOpTranscriber

* fix coverage

* fix 0yaml

* [Precision Depth Alignment] paddle.log_sigmoid (#75898)

* accuracy_stable_log_sigmoid

* fix test_activation_stride_op.py

* [Precision Depth Alignment] Modify the negative_slope parameter of the paddle.nn.functional.leaky_relu API to double (#75547)

* [big tensor] Paddle/paddle/phi/kernels/funcs gpuBigtensor (#75856)

* fix funcs

* gpu

* fix

* fix

* 修改PADDLE_ENFORCE信息

* fix cpu error

* fix dcu

* fix dcu

* fix

* [Fix] log sigmoid complex (#75953)

* feature: Add specialized LogSigmoidFunctor and CudaLogSigmoidFunctor for complex numbers

This commit introduces specialized implementations of LogSigmoidFunctor and CudaLogSigmoidFunctor to handle complex number inputs. The new implementations utilize direct formulas for improved accuracy and stability in calculations involving complex types.

* refactor: Optimize LogSigmoidFunctor and CudaLogSigmoidFunctor for complex types by caching exp(-x) to reduce redundant computations. This change enhances performance while maintaining accuracy in calculations.

* refactor: modified the formula in LogSigmoidFunctor to make it numerical stable

---------

Co-authored-by: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com>
Co-authored-by: 正在学习 <62892980+cszdrg@users.noreply.github.com>
Co-authored-by: Bvicii <98971614+scyyh11@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants