From 7e4480b4a733a4027a0afdfc4ba8ccb4bddd97e7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 11 Dec 2021 23:07:25 -0500 Subject: [PATCH] fast activation functions --- src/layers/basic.jl | 3 ++- src/layers/conv.jl | 12 ++++++++---- src/layers/recurrent.jl | 11 ++++++----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e40457ef53..f2b0511145 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -154,7 +154,8 @@ end @functor Dense function (a::Dense)(x::AbstractVecOrMat) - W, b, σ = a.weight, a.bias, a.σ + W, b= a.weight, a.bias + σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc return σ.(W*x .+ b) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 6cb564924e..8cac79b803 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -161,7 +161,8 @@ end @functor Conv function (c::Conv)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) + b = reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1) + σ = NNlib.fast_act(c.σ, x) cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) σ.(conv(x, c.weight, cdims) .+ b) end @@ -278,7 +279,8 @@ end @nograd conv_transpose_dims function (c::ConvTranspose)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + b = reshape(c.bias, map(_->1, c.stride)..., :, 1) + σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) σ.(∇conv_data(x, c.weight, cdims) .+ b) end @@ -371,7 +373,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1]) function (c::DepthwiseConv)(x) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + b = reshape(c.bias, map(_->1, c.stride)..., :, 1) + σ = NNlib.fast_act(c.σ, x) cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) σ.(depthwiseconv(x, c.weight, cdims) .+ b) end @@ -450,7 +453,8 @@ function crosscor(x, w, ddims::DenseConvDims) end function (c::CrossCor)(x::AbstractArray) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + b = reshape(c.bias, map(_->1, c.stride)..., :, 1) + σ = NNlib.fast_act(c.σ, x) cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) σ.(crosscor(x, c.weight, cdims) .+ b) end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 87f77c565a..989015e59a 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -117,7 +117,8 @@ RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zero RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T} - σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + Wi, Wh, b = m.Wi, m.Wh, m.b + σ = NNlib.fast_act(m.σ, x) h = σ.(Wi*x .+ Wh*h .+ b) return h, reshape_cell_output(h, x) end @@ -224,8 +225,8 @@ function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{Abstr b, o = m.b, size(h, 1) g = m.Wi*x .+ m.Wh*h .+ b input, forget, cell, output = multigate(g, o, Val(4)) - c′ = @. σ(forget) * c + σ(input) * tanh(cell) - h′ = @. σ(output) * tanh(c′) + c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell) + h′ = @. sigmoid_fast(output) * tanh_fast(c′) return (h′, c′), reshape_cell_output(h′, x) end @@ -309,7 +310,7 @@ function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},O Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1) gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) - h̃ = @. tanh(gxs[3] + r * ghs[3] + bs[3]) + h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3]) h′ = @. (1 - z) * h̃ + z * h return h′, reshape_cell_output(h′, x) end @@ -387,7 +388,7 @@ function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T} Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1) gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3)) r, z = _gru_output(gxs, ghs, bs) - h̃ = tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) + h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3]) h′ = @. (1 - z) * h̃ + z * h return h′, reshape_cell_output(h′, x) end