
Conversation

@mcabbott mcabbott commented Jan 11, 2022

This upgrades many of the activation functions. For instance, the new gelu is about 7x faster (when broadcast), and its worst-case error is 9 steps of nextfloat, instead of 7675 before. You can check my working below if so inclined.

It also defines a function such that fast_act(tanh) == tanh_fast etc., to allow Flux to automatically substitute the faster-than-Base variants. For all except tanh and sigmoid, however, there seems to be no point in keeping the slow versions around at all.
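
For illustration, a layer could pick up the fast variant like this (just a sketch; the Dense-like layer here is hypothetical, not Flux's actual code):

using NNlib

struct MyDense{M,B,F}   # hypothetical layer, only to show the call pattern
    weight::M
    bias::B
    σ::F
end

function (d::MyDense)(x::AbstractArray)
    act = NNlib.fast_act(d.σ, x)   # e.g. tanh -> tanh_fast, chosen given this array
    act.(d.weight * x .+ d.bias)
end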

The one exception to these upgrades is softplus, for which a fast version is surely possible, but it isn't worked out yet, so it's not included here.

Paging @oscardssmith, since we discussed things in #345

using NNlib, BenchmarkTools, Statistics
# From the tests:
function countepsfrom(x::T, xtrue) where {T<:AbstractFloat}
    target = T(xtrue)
    for n in Iterators.flatten(zip(0:100, -1:-1:-100))
        nextfloat(x, n) === target && return n
    end
    return round(Int, (target - x) / eps(x))
end

mean_eps(f, g, xs) = mean(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
worst_eps(f, g, xs) = maximum(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
function find_worst(f, g, xs)
    c, i = findmax(x -> abs(countepsfrom(f(x), g(big(x)))), xs)
    c, xs[i]
end
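
# These measure error in units of nextfloat steps (roughly ULPs), comparing f(x)
# against a BigFloat reference g(big(x)); find_worst also reports the worst-case input.
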
# Also needed:
mean_abs(f, g, xs) = mean(x -> abs(f(x) - g(big(x))), xs) |> Float32
oftf(x, y) = oftype(float(x), y)

# In the same order as activations.jl:

########## ELU

elu_old(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))

deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + α)

elu_fast(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath α * (exp(x) - 1))

#=

# No downside, just replace:

julia> @btime y .= elu_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  6.175 μs (0 allocations: 0 bytes)

julia> @btime y .= elu_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  871.268 ns (0 allocations: 0 bytes)

julia> mean_abs(elu_old, elu_old, -4:0.01f0:4)
7.757981f-9

julia> mean_abs(elu_fast, elu_old, -4:0.01f0:4)
7.757981f-9

julia> find_worst(elu_old, elu_old, -4:0.01f0:4)
(11, -0.03f0)

julia> find_worst(elu_fast, elu_old, -4:0.01f0:4)
(11, -0.03f0)

=#

########## GELU

function gelu_old(x)
    α = oftf(x, 0.044715)
    λ = oftf(x, gelu_λ)
    x/2 * (1 + tanh(λ * (x + α * x^3)))
end

const gelu_λ = √(2 / π)
const gelu_2λ_ = √(8 / π)

@inline function gelu_tf(x)
    α = oftf(x, 0.044715)
    λ = oftf(x, gelu_λ)
    x/2 * (1 + tanh_fast(λ * x * muladd(x^2, α, one(x))))
end

@inline function gelu_fast(x)
    α = oftf(x, 0.044715)
    λλ = oftf(x, gelu_2λ_)
    x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))
end

@inline function gelu_s(x)
    α = oftf(x, 0.044715)
    λλ = oftf(x, gelu_2λ_)
    x * sigmoid(λλ * x * muladd(x^2, α, one(x)))
end
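
# Note: gelu_fast and gelu_s rewrite gelu using the identity tanh(u) == 2σ(2u) - 1, so that
#   x/2 * (1 + tanh(λ * v)) == x * σ(2λ * v),
# hence gelu_2λ_ = √(8/π) = 2 * gelu_λ.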

#=

# gelu_fast is more accurate, and faster. Just replace?
# (gelu_tf is even faster, but much less accurate.)


julia> @btime y .= gelu_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 6.548 μs, mean 7.560 μs (0 allocations)

julia> @btime y .= gelu_tf.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 545.635 ns, mean 548.751 ns (0 allocations)

julia> @btime y .= gelu_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 977.733 ns, mean 991.673 ns (0 allocations)

julia> @btime y .= gelu_s.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 3.401 μs, mean 3.446 μs (0 allocations)


julia> mean_abs(gelu_old, gelu_s, -5:0.01f0:5)
4.537188f-8

julia> mean_abs(gelu_tf, gelu_s, -5:0.01f0:5)
7.47122f-8

julia> mean_abs(gelu_fast, gelu_s, -5:0.01f0:5)
4.5121787f-8


julia> find_worst(gelu_old, gelu_s, -4f0:0.01f0:0f0)
(7675, -3.97f0)

julia> find_worst(gelu_fast, gelu_s, -4f0:0.01f0:0f0)
(9, -3.93f0)

julia> find_worst(gelu_tf, gelu_s, -4f0:0.01f0:0f0)
(33254, -3.93f0)


=#

########## SWISH

swish_old(x) = x * σ(x)

@inline swish_fast(x) = x * sigmoid_fast(x)

#=

# No downside, just replace. The @inline annotation matters here.

julia> @btime y .= swish_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 2.815 μs, mean 2.859 μs (0 allocations)

julia> @btime y .= swish_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 793.367 ns, mean 801.710 ns (0 allocations)

julia> mean_abs(swish_old, swish_old, -4:0.01f0:4)
3.824092f-8

julia> mean_abs(swish_fast, swish_old, -4:0.01f0:4)
3.824092f-8

julia> find_worst(swish_old, swish_old, -4:0.01f0:4)
(2, -3.43f0)

julia> find_worst(swish_fast, swish_old, -4:0.01f0:4)
(2, -3.43f0)

=#

########## LISHT

lisht_old(x) = x * tanh(x)

lisht_fast(x) = x * tanh_fast(x)

#=

# Just replace:


julia> @btime y .= lisht_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 3.139 μs, mean 3.600 μs (0 allocations)

julia> @btime y .= lisht_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 385.261 ns, mean 387.571 ns (0 allocations)

julia> mean_abs(lisht_old, lisht_old, -4:0.01f0:4)
4.8593183f-8

julia> mean_abs(lisht_fast, lisht_old, -4:0.01f0:4)
9.2598036f-8

julia> find_worst(lisht_old, lisht_old, -4:0.01f0:4)
(1, -3.94f0)

julia> find_worst(lisht_fast, lisht_old, -4:0.01f0:4)
(2, -3.97f0)

=#

########## SELU

function selu_old(x)
    λ = oftf(x, selu_λ)
    α = oftf(x, selu_α)
    λ * ifelse(x > 0, x, α * (exp(x) - 1))
end

const selu_λ = 1.0507009873554804934193349852946
const selu_α = 1.6732632423543772848170429916717

function deriv_selu(Ω)
    λ = oftf(Ω, selu_λ)
    α = oftf(Ω, selu_α)
    ifelse(Ω > 0, λ, Ω + α * λ)
end

@inline function selu_fast(x)
    λ = oftf(x, selu_λ)
    α = oftf(x, selu_α)
    λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1))
end


#=

# Just replace:

julia> @btime y .= selu_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 2.713 μs, mean 2.820 μs (0 allocations)

julia> @btime y .= selu_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 671.474 ns, mean 692.950 ns (0 allocations)

julia> mean_abs(selu_old, selu_old, -4:0.01f0:4)
5.8184565f-8

julia> mean_abs(selu_fast, selu_old, -4:0.01f0:4)
5.5488744f-8

julia> find_worst(selu_old, selu_old, -4:0.01f0:4)
(10, -0.03f0)

julia> find_worst(selu_fast, selu_old, -4:0.01f0:4)
(10, -0.03f0)

=#

########## SOFTPLUS

softplus_old(x) = log1p(exp(-abs(x))) + max(x, 0)

@inline softplus_macro(x) = @fastmath log(1 + exp(-abs(x))) + relu(x)

function softplus_fast(x::Float32)
    absx = abs(x)
    # At large negative x, approximate the tail as exp(x):
    expm_absx = @fastmath exp(-absx)
    _softplus_fast(x, expm_absx)
end
@inline function _softplus_fast(x::Float32, expm_absx::Float32)
    # Near to zero is where the log() matters, replace with a minimax rational function:
    main = evalpoly(x,
        (0.6931472f0, 0.33953527f0, 0.09100373f0, 0.016568843f0, 0.0013059091f0, 1.8928275f-6)
    )/evalpoly(x, 
        (1.0f0, 0.21118937f0, 0.11348734f0, 0.0145165585f0, 0.0014029884f0)
    )
    # Finally, at moderate negative x, neither of these is very accurate, so add a 3rd piece:
    negfact = evalpoly(x, 
        (0.99217963f0, 0.0027701536f0, -0.00039373388f0, 2.8039218f-5, -9.995612f-7, 1.4259708f-8)
    )
    ifelse(x < -15, expm_absx, ifelse(x < -10, expm_absx * negfact, expm_absx * main + relu(x)))
end
softplus_fast(x::Real) = softplus(x)


#=

# Here it's not so clear what to do. Maybe nothing, for now.
 
julia> @btime y .= softplus_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  20.500 μs (0 allocations: 0 bytes)
  9.250 μs on 1.8

julia> @btime y .= softplus_macro.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  10.750 μs (0 allocations: 0 bytes)
  7.125 μs on 1.8

julia> @btime y .= softplus_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  9.458 μs (0 allocations: 0 bytes)
  4.881 μs on 1.8

julia> mean_abs(softplus_old, softplus_old, -4:0.01f0:4)
2.6509584f-8

julia> mean_abs(softplus_macro, softplus_old, -4:0.01f0:4)
4.1025547f-8

julia> mean_abs(softplus_fast, softplus_old, -4:0.01f0:4)
0.038294494f0

julia> find_worst(softplus_old, softplus_old, -4:0.01f0:4)
(1, -4.0f0)

julia> find_worst(softplus_macro, softplus_old, -4:0.01f0:4)
(31, -3.98f0)

julia> find_worst(softplus_fast, softplus_old, -4:0.01f0:4)
(317443104, -3.93f0)

=#

########## tanhshrink

tanhshrink_old(x) = x - tanh(x)

@inline tanhshrink_fast(x) = x - tanh_fast(x)

#=

# Just replace?

julia> @btime y .= tanhshrink_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 3.255 μs, mean 3.727 μs (0 allocations)

julia> @btime y .= tanhshrink_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  min 384.852 ns, mean 386.998 ns (0 allocations)

julia> mean_abs(tanhshrink_old, tanhshrink_old, -4:0.01f0:4)
2.8952828f-8

julia> mean_abs(tanhshrink_fast, tanhshrink_old, -4:0.01f0:4)
4.6727695f-8

julia> find_worst(tanhshrink_old, tanhshrink_old, -4:0.01f0:4)
(4697, -0.02f0)

julia> find_worst(tanhshrink_fast, tanhshrink_old, -4:0.01f0:4)
(3495, -0.02f0)

=#

########## softshrink

softshrink_old(x, λ=oftf(x, 0.5)) = min(max(0, x - λ), x + λ)

softshrink_fast(x, λ=oftf(x, 0.5)) = clamp(0, x - λ, x + λ)

#=

# Just replace

julia> @btime y .= softshrink_old.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  429.437 ns (0 allocations: 0 bytes)

julia> @btime y .= softshrink_fast.(x)   setup=(x=randn(Float32,1000); y=similar(x));
  166.884 ns (0 allocations: 0 bytes)

julia> mean_abs(softshrink_old, softshrink_old, -4:0.01f0:4)
0.0f0

julia> mean_abs(softshrink_fast, softshrink_old, -4:0.01f0:4)
0.0f0

=#

@oscardssmith

Well done! This looks great.

A review thread follows on this new definition (quoted here with the tail of its docstring):

Takes an optional 2nd argument, so that you can disable
this replacement for some array or element types.
"""
@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
@ToucheSir ToucheSir Jan 11, 2022

👀 I've never seen this ::AbstractArray = 1:0 pattern before, what does it do? Edit: figured that part out, is there a benefit to 1:0 over 0:0 or 0:1?

@mcabbott (Member Author)

My thinking is that we may later decide that it's better to skip tanh_fast on the GPU. I can't measure a difference, so who knows. To do that, we can add a method in NNlibCUDA like fast_act(::typeof(tanh), ::CuArray) = tanh. This works provided you call it like fast = NNlib.fast_act(fun, x), passing an example of the array you plan to broadcast over.
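
For example (just a sketch, with a made-up array type standing in for CuArray):

using NNlib

struct FakeGPUArray{T,N} <: AbstractArray{T,N} end   # hypothetical stand-in for CuArray

# the kind of opt-out method NNlibCUDA could add:
NNlib.fast_act(::typeof(tanh), ::FakeGPUArray) = tanh

NNlib.fast_act(tanh)                              # tanh_fast, via the generic default argument 1:0
NNlib.fast_act(tanh, FakeGPUArray{Float32,1}())   # tanh, because the more specific method wins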

@mcabbott (Member Author)

My branch for this is mcabbott/Flux.jl@217eb1a
which I see I did on top of FluxML/Flux.jl#1761

@mcabbott mcabbott merged commit e928fad into FluxML:master Jan 11, 2022
@mcabbott mcabbott deleted the activate3 branch January 11, 2022 04:22
@ToucheSir
Member

Flux is not happy: https://github.com/FluxML/Flux.jl/runs/4770890956?check_suite_focus=true. I think it's time to look into reverse CI.

@mcabbott
Member Author

Oh no, sorry. Probably not hard to fix though:

julia> methods(tanh_fast)
# 4 methods for generic function "tanh_fast":
[1] tanh_fast(x::AbstractArray, args...) in NNlib at /Users/me/.julia/packages/NNlib/tAGmA/src/activations.jl:663
[2] tanh_fast(x::Float32) in NNlib at /Users/me/.julia/packages/NNlib/tAGmA/src/activations.jl:694
[3] tanh_fast(x::Float64) in NNlib at /Users/me/.julia/packages/NNlib/tAGmA/src/activations.jl:703
[4] tanh_fast(x::Real) in NNlib at /Users/me/.julia/packages/NNlib/tAGmA/src/activations.jl:715

julia> Flux.NilNumber.Nil |> supertype
Number

Either we change the supertype, or we widen the method.
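
Widening would be something like this (just a sketch, not necessarily the exact change needed):

tanh_fast(x::Number) = tanh(x)   # fall back to Base tanh for number types with no fast path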

@ToucheSir
Member

cc @darsnack for his thoughts.

@darsnack
Member

I think we tried making Nil an AbstractFloat, and all kinds of things that worked "for free" in the PR broke, so we switched back. Arguably, the default for Nil should be to try to apply it in the widest setting.

Widening tanh_fast seems like the thing to do.

@mcabbott
Member Author

Yes I remember something about AbstractFloat causing problems. But does Real cause problems?

Should probably widen here anyway, now that I think about it, in case someone is crazy enough to use Flux with complex numbers.

@mcabbott
Member Author

Although in fact, now that I look at the file, few activation functions will work on ::Complex, since they often compare x < 0 etc. So really they should all have signature Real. Which does argue for Nil to change too, unless this causes some disaster elsewhere?

@darsnack
Member

Yeah, I just tested it: subtyping Real does break a lot. Nothing impossible to fix, just niceties that are defined for Number and not for Real.

Ideally, Nil would be used for any number type including Complex, since it is only used in a forward context.
