NaN caused by tanh_fast #407

Closed
zhubonan opened this issue May 5, 2022 · 2 comments · Fixed by #408

zhubonan commented May 5, 2022

FluxML/Flux.jl#1776 substitutes tanh_fast for tanh; however, tanh_fast returns NaN for large positive inputs:

I am on Julia 1.7.2 and Flux 0.13.

julia> using Flux

julia> tanh_fast(1e10)
NaN

julia> tanh_fast(-1e10)
-1.0

julia> d = Dense(1=>1, tanh)
Dense(1 => 1, tanh)  # 2 parameters

julia> d([-1.e16])
1-element Vector{Float64}:
 NaN

julia> d([1.e16])
1-element Vector{Float64}:
 -1.0

The NaN would then propagate through the rest of the network.
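
The asymmetry between the two signs can be reproduced with Base Julia alone. This is a minimal sketch, assuming tanh_fast evaluates (exp(2x) - 1) / (exp(2x) + 1) as the NNlib source does; `naive_tanh` and `x_large` are hypothetical names used only for illustration, not part of NNlib or Flux:

```julia
# Hypothetical helper mirroring the formula; this is not NNlib code.
naive_tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)

x_large = 1e10

exp(2 * x_large)     # Inf: the exponential overflows Float64
exp(-2 * x_large)    # 0.0: the exponential underflows

naive_tanh(x_large)  # (Inf - 1) / (Inf + 1) = Inf / Inf = NaN
naive_tanh(-x_large) # (0 - 1) / (0 + 1) = -1.0
sign(NaN)            # NaN: taking the sign of the result cannot repair it
tanh(x_large)        # 1.0: Base tanh saturates correctly
```

So only the positive side breaks: the negative side underflows to a well-defined -1.0, while the positive side ends up as Inf / Inf.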

@zhubonan zhubonan changed the title NaN caused by by tanh_fast NaN caused by tanh_fast May 5, 2022

zhubonan commented May 5, 2022

Er, maybe this should be raised in NNlib.jl instead?

@mcabbott mcabbott transferred this issue from FluxML/Flux.jl May 5, 2022

mcabbott commented May 5, 2022

I think that's a stupid bug: There's a test which should use a constant for x > sqrt(900) to avoid this, but it uses sign(y) not sign(x).

NNlib.jl/src/activations.jl

Lines 776 to 784 in 886b34c

@inline function tanh_fast(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1)
    # That has large errors near zero; using `expm1` would more accurate, but about as slow as `tanh`.
    # Instead, we switch to a polynomial, which is very accurate within its range:
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    ifelse(x2 > 900.0, sign(y), ifelse(x2 < 0.017, oftype(y, ypoly), y))
end
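
For reference, here is a sketch of the change described above: the same body with sign(x) in the saturation branch instead of sign(y). The name `tanh_fast_fixed` is hypothetical, and this is only an illustration of the idea; the change actually merged in #408 may differ.

```julia
# Hypothetical name; identical to the excerpt above except for the saturation branch.
@inline function tanh_fast_fixed(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1)   # NaN once exp2x overflows to Inf
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    # When x2 > 900, x is nonzero and finite or infinite, so sign(x) is exactly ±1.0.
    # sign(y) would instead return NaN for large positive x, because y is already NaN.
    ifelse(x2 > 900.0, sign(x), ifelse(x2 < 0.017, oftype(y, ypoly), y))
end
```

With this version, `tanh_fast_fixed(1e10)` returns 1.0 and `tanh_fast_fixed(-1e10)` returns -1.0. Inputs with x2 <= 900 take the same branches as before.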
