Skip to content

Avoid maximum in softmax #450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ See also [`logsoftmax`](@ref).

# Examples

```jldoctest
```jldoctest; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> softmax([1, 2, 3])
3-element Vector{Float64}:
0.09003057317038046
Expand Down Expand Up @@ -58,13 +58,14 @@ softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T))
softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
max_ = maximum(x; dims)
max_ = fast_maximum(x; dims)
if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
else
@fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
end
out ./= sum(out; dims)
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
out ./= tmp
end

function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
Expand All @@ -75,7 +76,7 @@ function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) wh
# This path is faster, only safe for 1st derivatives though.
# Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
# but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
out = similar(y, promote_type(T,S))
out = similar(y, promote_type(T,S)) # sure to be mutable
out .= dy .* y
out .= out .- y .* sum(out; dims)
end
Expand All @@ -90,6 +91,7 @@ end
within_grad() = false
rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)

fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

"""
logsoftmax(x; dims = 1)
Expand All @@ -109,7 +111,7 @@ logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, flo
logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
max_ = maximum(x; dims)
max_ = fast_maximum(x; dims)
if all(isfinite, max_)
out .= x .- max_
else
Expand Down