Merge pull request #247 from norci/inplace_softmax
Added softmax! and reduced softmax's memory usage.
CarloLucibello authored Dec 21, 2020
2 parents cab2e56 + 5a9f818 commit bd63541
Showing 2 changed files with 95 additions and 156 deletions.
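In short, the new code exports softmax! and logsoftmax! and rewrites softmax(x) so that it fills a single output buffer instead of materialising separate max, exp, and result arrays. A minimal usage sketch of the in-place API, with illustrative sizes not taken from this PR:

using NNlib

x = rand(Float32, 128, 32)   # a batch of logits (hypothetical shape)
y = softmax(x; dims = 1)     # allocates one full-size output (plus small reduction buffers)
out = similar(x)
softmax!(out, x; dims = 1)   # writes into a caller-provided buffer
softmax!(x; dims = 1)        # fully in place, overwriting x
out ≈ y && x ≈ y             # all three agree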
113 changes: 38 additions & 75 deletions src/softmax.jl
@@ -1,6 +1,12 @@
export softmax, softmax!, ∇softmax, ∇softmax!,
logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
logsumexp
export softmax,
softmax!,
∇softmax,
∇softmax!,
logsoftmax,
logsoftmax!,
∇logsoftmax,
∇logsoftmax!,
logsumexp

"""
softmax(x; dims=1)
@@ -26,49 +32,20 @@ julia> softmax([1, 2, 3])
See also [`logsoftmax`](@ref).
"""
function softmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
exp_ ./ sum(exp_, dims=dims)
softmax(x; dims = 1) = softmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
softmax!(x; dims = 1) = softmax!(x, x; dims = dims)
function softmax!(out::O, x::T; dims = 1) where {O<:AbstractArray,T<:AbstractArray}
out .= exp.(x .- maximum(x; dims = dims))
out ./= sum(out; dims = dims)
end

function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
@inbounds for j = 1:size(xs, 2)
# First, store column-wise maximum in the last element of `out`
out[end, j] = xs[end, j]
@inbounds for i = 1:(size(xs, 1) - 1)
out[end, j] = max(out[end, j], xs[i, j])
end

# Subtract the column-wise maximums to normalize, take exp()
# out .= exp(xs .- out[end, :])
@inbounds for i = 1:size(out, 1)
out[i, j] = exp(xs[i, j] - out[end, j])
end

# Normalize by sum of the entire thing
# out ./= sum(out, 1)
s = T(0)
@inbounds for i = 1:size(out, 1)
s += out[i, j]
end
@inbounds for i = 1:size(out, 1)
out[i, j] /= s
end
out
end
∇softmax(Δ, x; dims = 1) = ∇softmax!(similar(Δ), Δ, x; dims = dims)
∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x; dims = dims)
function ∇softmax!(out::O, Δ::O, x::T; dims = 1) where {O<:AbstractArray,T<:AbstractArray}
softmax!(out, x; dims = dims)
out .*= Δ .- sum(Δ .* out; dims = dims)
end

function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
sf = softmax(xs)
out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
end
function ∇softmax(Δ, xs; dims=1)
sf = softmax(xs, dims=dims)
sf .* (Δ .- sum(Δ .* sf, dims=dims))
end
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)


"""
logsoftmax(x; dims=1)
@@ -83,39 +60,26 @@ It is semantically equivalent to the following:
See also [`softmax`](@ref).
"""
function logsoftmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
log_ = log.(sum(exp_, dims=dims))
(xs .- max_) .- log_
logsoftmax(x; dims = 1) = logsoftmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims = dims)
function logsoftmax!(out::O, x::T; dims = 1) where {O<:AbstractArray,T<:AbstractArray}
out .= x .- maximum(x; dims = dims)
# out .-= log.(sum(exp.(out); dims = dims)) # WARN: this will decrease performance.
log_ = log.(sum(exp.(out); dims = dims))
out .-= log_
end

function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
for j = 1:size(xs, 2)
@inbounds begin
xi_max = xs[1, j]
for i = 1:size(out, 1)
xi_max = max(xi_max, xs[i, j])
end
s = zero(eltype(out))
for i = 1:size(out, 1)
s += exp(xs[i, j] - xi_max)
end
for i = 1:size(out, 1)
out[i, j] = xs[i, j] - log(s) - xi_max
end
end
end
return out
∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax!(similar(Δ), Δ, x; dims = dims)
∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x; dims = dims)
function ∇logsoftmax!(
out::O,
Δ::O,
x::T;
dims = 1,
) where {O<:AbstractArray,T<:AbstractArray}
out .= Δ .- sum(Δ, dims = dims) .* softmax!(out, x; dims = dims)
end

function ∇logsoftmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
out .= Δ .- sum(Δ, dims=1) .* softmax(xs, dims=1)
end

∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
∇logsoftmax!(Δ, xs) = ∇logsoftmax!(Δ, Δ, xs)

"""
logsumexp(x; dims=:)
@@ -124,8 +88,7 @@ way.
See also [`logsoftmax`](@ref).
"""
function logsumexp(xs::AbstractArray; dims=:)
max_ = maximum(xs, dims=dims)
log_ = log.(sum(exp.(xs .- max_), dims=dims))
return max_ .+ log_
function logsumexp(x::AbstractArray; dims = :)
max_ = maximum(x; dims = dims)
max_ .+ log.(sum(exp.(x .- max_); dims = dims))
end
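Both logsoftmax! and logsumexp above rely on subtracting the per-slice maximum before exponentiating; that is what makes them numerically stable. A tiny illustration with made-up values, assuming NNlib is loaded:

x = [1000.0, 1000.0]
log(sum(exp.(x)))                              # naive formula: exp(1000) overflows, returns Inf
maximum(x) + log(sum(exp.(x .- maximum(x))))   # shifted formula: 1000 + log(2) ≈ 1000.6931
logsumexp(x)                                   # the exported helper gives the same stable result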
138 changes: 57 additions & 81 deletions test/softmax.jl
@@ -1,111 +1,87 @@
using Zygote
using Statistics: mean

@testset "softmax integer input" begin
@test softmax(Int[0, 0]) == [0.5, 0.5]
end

@testset "softmax on different dims" begin
xs = rand(fill(2, 5)...)
out = similar(xs)
for (fn!, fn) in [(softmax!, softmax), (logsoftmax!, logsoftmax)], i = 1:ndims(xs)
@test fn!(out, xs; dims = i) == fn(xs; dims = i)
end
end

@testset "softmax" begin
xs = rand(5,5)
xs = rand(5, 5)
@test all(sum(softmax(xs), dims = 1) .≈ 1)
@test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1)
@test all(sum(softmax(xs; dims = 2), dims = 2) .≈ 1)
@test sum(softmax(vec(xs))) ≈ 1
@test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2)
@test log.(softmax(xs; dims = 2)) ≈ logsoftmax(xs; dims = 2)

xs = [-100_000, -100_000.]
xs = [-100_000.0, -100_000.0]
@test softmax(xs) ≈ [0.5, 0.5]
@test logsoftmax(xs) ≈ log.([0.5, 0.5])

xs = rand(5)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs))
@test logsoftmax(xs) ≈ log.(softmax(xs))

xs = Float32[1, 2, 3000.]
xs = Float32[1, 2, 3000.0]
@test logsoftmax(xs) ≈ [-2999, -2998, 0]

xs = Float32[1 2 3; 1000 2000 3000]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]

@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
@test ∇logsoftmax(ones(Float32, size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test ∇softmax(ones(Float32, size(xs)), xs) ≈ zeros(Float32, size(xs))

# These values precalculated using PyTorch's nn.LogSoftmax
xs = [
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663;
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
]
ys = [
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273;
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428;
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428
0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697
]
@test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6)
@test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6)
@test ∇logsoftmax(ones(size(xs)), xs) ≈ ys rtol = 1e-6
@test ∇softmax(ones(size(xs)), xs) ≈ zeros(size(xs)) atol = 1e-6
end

@testset "mutating softmax" begin
xs = Float64[1 2 3; 5 6 7]

out = zeros(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = zeros(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
out = zeros(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
out = ones(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)

xs = [
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663;
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
]

out = zeros(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = zeros(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
out = zeros(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
out = ones(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
map([
Float64[1 2 3; 5 6 7],
Float64[
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
],
]) do xs
out = similar(xs)
softmax!(out, xs)
@test out ≈ softmax(xs) rtol = 1e-6
logsoftmax!(out, xs)
@test out ≈ logsoftmax(xs) rtol = 1e-6

map([zeros, ones]) do fn
Δ = fn(Float64, size(xs))
∇softmax!(out, Δ, xs)
@test out ≈ ∇softmax(Δ, xs) rtol = 1e-6
∇logsoftmax!(out, Δ, xs)
@test out ≈ ∇logsoftmax(Δ, xs) rtol = 1e-6
end
end
end

@testset "logsumexp" begin
flogsoft(x; dims) = mean(x .- logsoftmax(x; dims=dims), dims=dims)
x = rand(3,4)
@test logsumexp(x) ≈ flogsoft(x, dims=:)
@test logsumexp(x; dims=1) ≈ flogsoft(x, dims=1)
@test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims=:), x)[1]
flogsoft(x; dims) = mean(x .- logsoftmax(x; dims = dims), dims = dims)

x = rand(3, 4)
@test logsumexp(x) ≈ flogsoft(x, dims = :)
@test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
@test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims = :), x)[1]
end
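Beyond the fixed reference values checked above, one extra sanity check that could be run locally (a sketch, not part of this test suite) is to compare the analytic pullbacks against Zygote's AD, which the test file already imports:

using NNlib, Zygote

x, Δ = rand(3, 4), rand(3, 4)
_, back = Zygote.pullback(x -> softmax(x; dims = 1), x)
@assert back(Δ)[1] ≈ ∇softmax(Δ, x; dims = 1)       # VJP of softmax matches ∇softmax
_, back = Zygote.pullback(x -> logsoftmax(x; dims = 1), x)
@assert back(Δ)[1] ≈ ∇logsoftmax(Δ, x; dims = 1)    # VJP of logsoftmax matches ∇logsoftmax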
