Skip to content

Conversation

@zhengshengning
Copy link
Contributor

@zhengshengning zhengshengning commented Sep 22, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

主要修改点如下:

  1. 增加对 float16 类型的分支处理,其它类型走原逻辑不变。
  2. 在 float16 分支中使用 CUDA half :将 out 显式转换为半精度,用 __hmul 计算 out^2,显式“先乘”,避免被融合为 FMA。
  3. 非 float16 类型保持原实现:dout * (one - out * out)。

测试结果:

  1. paddle.nn.functional.tanh 全部与torch对齐(共120个case);
  2. paddle.tanh 全部与torch对齐(共142个case);

pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Sep 22, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zhengshengning zhengshengning changed the title [Precision Depth Alignment] Paddle.tanh Reverse Gradient Calculation and Torch Accuracy Alignment [Precision Depth Alignment] paddle.tanh reverse gradient calculation and Torch accuracy alignment. Sep 22, 2025
if constexpr (std::is_same<T, phi::float16>::value) {
__half out_half = __float2half_rn(static_cast<float>(out));
__half tmp_half = __hmul(out_half, out_half);
return dout * (one - static_cast<T>(__half2float(tmp_half)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

进了这个if, T应该就是float16了,为什么还要cast to float然后static cast回来,接着做减法?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为没找到float16可以直接转_half的,所以先转成float

Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@A-nnonymous A-nnonymous left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhengshengning zhengshengning merged commit 67c20ac into PaddlePaddle:develop Sep 24, 2025
81 of 84 checks passed
wanglezz pushed a commit to wanglezz/Paddle that referenced this pull request Sep 25, 2025
zhengshengning added a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
zhengshengning added a commit to zhengshengning/Paddle that referenced this pull request Oct 24, 2025
zhengshengning added a commit that referenced this pull request Oct 27, 2025
* CallScalarFunction uses the dtype of 'self' as the type of 'other' when opotype is 'div'(#75237)

* LinspaceKernel uses the dtype of 'self' as the type of 'step' when tensor is floating (#75238)

* align LinspaceKernel

* update meta

* update gpu kernel

* fix LinspaceKernelInner

* improve kernel

* fix CudaSigmoidGradFunctor and CudaSiluGradFunctor (#75341)

* Softplus accuracy and torch alignment 1 (#75363)

* [Precision Depth Alignment] paddle.tan reverse calculation: dx = dout *(1 + tan(x)^2) (#75335)

* Tan reverse calculation: dx = dout *(1 + tan(x)^2)

* [Precision Depth Alignment] Add support for CUDNN to paddle.nn.functional.grid_sample to align with torch accuracy.  (#75355)

* accuracy_stable_grid_sample

* fix

* correlation supports big tensor (#75383)

* fix

* fix test

* fix

* paddle.tanh Grad and torch alignment (float16) (#75454)

* [Precision Depth Alignment] paddle.sin and paddle.cos aligns with torch precision. (#75503)

* accuracy_stable_sin

* accuracy_stable_cos

* [深度对齐]Divide (#75379)

* fix

* fix

* fix

* fix

* fix

* [Precision Depth Alignment] fix precision for float16 of paddle.tan backward (#75525)

* fix precision for float16 of paddle.tan backward

* fix else branch of CudaTanGradFunctor

* [Precision Depth Alignment] fix precision for  paddle.expm1 (#75549)

* accuracy_stable_expm1

* fix

* Bigtensor排查修复[Paddle/paddle/phi/kernels/funcs] (#75523)

* fix

* fix

* [Precision Depth Alignment]  fix beta and threshold of paddle.nn.functional.softplus  to double (#75426)

* fix beta and threshold of Softplus to double

* fix test_softplus_activation_fuse_pass v1

* fix test_activation_zero

* fix flaot of SoftplusDoubleGradKernel to double

* add op_patches for softplus

* add yaml for ops/yaml/legacy

* fix infershape/operator for FLOAT64

* fix

* add SoftPlusOpTranscriber

* fix

* fix

* fix1

* fix2

* fix coverage

* fix coverage2

* fix (#75605)

* [深度对齐] dot (#75717)

* fix

* fix

* fix dcu

* [Precision Depth Alignment]  paddle.log aligns with torch precision (#75799)

* accuracy_stable_log

* accuracy_stable_log

* fix

* fix

* fix

* fix

* fix5

* [Precision Depth Alignment] fix eps of paddle.logit from float to double (#75816)

* accuracy_stable_logit

* add LogitOpTranscriber

* fix coverage

* fix 0yaml

* [Precision Depth Alignment] paddle.log_sigmoid (#75898)

* accuracy_stable_log_sigmoid

* fix test_activation_stride_op.py

* [Precision Depth Alignment] Modify the negative_slope parameter of the paddle.nn.functional.leaky_relu API to double (#75547)

* [big tensor] Paddle/paddle/phi/kernels/funcs gpuBigtensor (#75856)

* fix funcs

* gpu

* fix

* fix

* 修改PADDLE_ENFORCE信息

* fix cpu error

* fix dcu

* fix dcu

* fix

* [Fix] log sigmoid complex (#75953)

* feature: Add specialized LogSigmoidFunctor and CudaLogSigmoidFunctor for complex numbers

This commit introduces specialized implementations of LogSigmoidFunctor and CudaLogSigmoidFunctor to handle complex number inputs. The new implementations utilize direct formulas for improved accuracy and stability in calculations involving complex types.

* refactor: Optimize LogSigmoidFunctor and CudaLogSigmoidFunctor for complex types by caching exp(-x) to reduce redundant computations. This change enhances performance while maintaining accuracy in calculations.

* refactor: modified the formula in LogSigmoidFunctor to make it numerical stable

---------

Co-authored-by: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com>
Co-authored-by: 正在学习 <62892980+cszdrg@users.noreply.github.com>
Co-authored-by: Bvicii <98971614+scyyh11@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants