added softmax!. reduced softmax's memory usage. #247

Merged
merged 7 commits on Dec 21, 2020
Changes from 2 commits
108 changes: 34 additions & 74 deletions src/softmax.jl
@@ -1,6 +1,12 @@
export softmax, softmax!, ∇softmax, ∇softmax!,
logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
logsumexp
export softmax,
softmax!,
∇softmax,
∇softmax!,
logsoftmax,
logsoftmax!,
∇logsoftmax,
∇logsoftmax!,
logsumexp

"""
softmax(x; dims=1)
@@ -26,49 +32,20 @@ julia> softmax([1, 2, 3])

See also [`logsoftmax`](@ref).
"""
function softmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
exp_ ./ sum(exp_, dims=dims)
softmax(x; dims = 1) = softmax!(similar(x), x; dims)
Member:
we still support julia < 1.5, so you will have to explicitly pass the keyword value
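
For reference, a minimal sketch of the two spellings (the bare `; dims` shorthand only parses on Julia 1.5 and newer, so supporting older releases means writing the value out):

# Julia >= 1.5 only: `f(...; dims)` is shorthand for `f(...; dims = dims)`
softmax(x; dims = 1) = softmax!(similar(x), x; dims)

# Portable to Julia < 1.5: pass the keyword value explicitly
softmax(x; dims = 1) = softmax!(similar(x), x; dims = dims)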

CarloLucibello (Member), Dec 19, 2020:
also, this pattern will error on integer array inputs, which we still want to support, so we should promote to float types:

similar(x, float(eltype(x)))

Contributor (Author):
fixed.
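
Putting both review suggestions together, the outer method might end up looking roughly like the sketch below (explicit keyword value plus a float output buffer so integer inputs work); the code in the later commits may differ:

softmax(x; dims = 1) = softmax!(similar(x, float(eltype(x))), x; dims = dims)

softmax(Int[0, 0])  # the Float64 buffer from float(eltype(x)) lets this return [0.5, 0.5]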

softmax!(x; dims = 1) = softmax!(x, x; dims)
function softmax!(out::T, x::T; dims = 1) where {T<:AbstractArray}
Member:
we can relax the constraint that `out` and `x` must be of the same type

Contributor (Author):
I have updated the constraint for softmax! and logsoftmax!.
For ∇softmax, I guess Δ and out should be the same type, right?

I added only one test case for integer input. Is this OK?

@test softmax(Int[0, 0]) == [0.5, 0.5]
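
As an illustration of the relaxed constraint under discussion, the in-place method could be typed as in the sketch below, requiring only that both arguments be AbstractArrays instead of sharing one concrete type (a sketch, not necessarily the merged code):

function softmax!(out::AbstractArray, x::AbstractArray; dims = 1)
    out .= exp.(x .- maximum(x; dims = dims))
    out ./= sum(out; dims = dims)
end

# An integer input with a Float64 output buffer now dispatches fine:
softmax!(zeros(2), Int[0, 0])  # == [0.5, 0.5]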

out .= exp.(x .- maximum(x; dims))
out ./= sum(out; dims)
end

function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
@inbounds for j = 1:size(xs, 2)
# First, store column-wise maximum in the last element of `out`
out[end, j] = xs[end, j]
@inbounds for i = 1:(size(xs, 1) - 1)
out[end, j] = max(out[end, j], xs[i, j])
end

# Subtract the column-wise maximums to normalize, take exp()
# out .= exp(xs .- out[end, :])
@inbounds for i = 1:size(out, 1)
out[i, j] = exp(xs[i, j] - out[end, j])
end

# Normalize by sum of the entire thing
# out ./= sum(out, 1)
s = T(0)
@inbounds for i = 1:size(out, 1)
s += out[i, j]
end
@inbounds for i = 1:size(out, 1)
out[i, j] /= s
end
out
end
∇softmax(Δ, x; dims = 1) = ∇softmax!(similar(Δ), Δ, x; dims)
∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x; dims)
function ∇softmax!(out::T, Δ::T, x::T; dims = 1) where {T<:AbstractArray}
softmax!(out, x; dims)
out .= out .* (Δ .- sum(Δ .* out; dims))
end

function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
sf = softmax(xs)
out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
end
function ∇softmax(Δ, xs; dims=1)
sf = softmax(xs, dims=dims)
sf .* (Δ .- sum(Δ .* sf, dims=dims))
end
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)


"""
logsoftmax(x; dims=1)
@@ -83,38 +60,22 @@ It is semantically equivalent to the following:

See also [`softmax`](@ref).
"""
function logsoftmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
log_ = log.(sum(exp_, dims=dims))
(xs .- max_) .- log_
end

function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
for j = 1:size(xs, 2)
@inbounds begin
xi_max = xs[1, j]
for i = 1:size(out, 1)
xi_max = max(xi_max, xs[i, j])
end
s = zero(eltype(out))
for i = 1:size(out, 1)
s += exp(xs[i, j] - xi_max)
end
for i = 1:size(out, 1)
out[i, j] = xs[i, j] - log(s) - xi_max
end
end
end
return out
logsoftmax(x; dims = 1) = logsoftmax!(similar(x), x; dims)
logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims)
function logsoftmax!(out::T, x::T; dims = 1) where {T<:AbstractArray}
out .= x .- maximum(x; dims)
# out .= out .- log.(sum(exp.(out); dims = dims)) # WARN: this will decrease performance.
log_ = log.(sum(exp.(out); dims))
out .-= log_
end

function ∇logsoftmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
out .= Δ .- sum(Δ, dims=1) .* softmax(xs, dims=1)
∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax!(similar(Δ), Δ, x; dims)
∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x; dims)
function ∇logsoftmax!(out::T, Δ::T, x::T; dims = 1) where {T<:AbstractArray}
softmax!(out, x; dims)
out .= Δ .- sum(Δ, dims = dims) .* out
end

∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
∇logsoftmax!(Δ, xs) = ∇logsoftmax!(Δ, Δ, xs)

"""
logsumexp(x; dims=:)
@@ -124,8 +85,7 @@ way.

See also [`logsoftmax`](@ref).
"""
function logsumexp(xs::AbstractArray; dims=:)
max_ = maximum(xs, dims=dims)
log_ = log.(sum(exp.(xs .- max_), dims=dims))
return max_ .+ log_
function logsumexp(x::AbstractArray; dims = :)
max_ = maximum(x; dims)
max_ .+ log.(sum(exp.(x .- max_); dims))
end
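
To connect the diff with the memory-usage claim in the PR title, a small usage sketch of the new in-place entry points (illustrative only; names and keywords as defined in the diff above):

x   = rand(Float32, 128, 32)
out = similar(x)

softmax!(out, x; dims = 1)     # result written into the preallocated out, no fresh output array
softmax!(x; dims = 1)          # fully in place: x is overwritten with softmax(x)
logsoftmax!(out, x; dims = 1)  # same pattern for logsoftmax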
126 changes: 45 additions & 81 deletions test/softmax.jl
@@ -1,111 +1,75 @@
using Zygote
using Zygote
using Statistics: mean

@testset "softmax" begin
xs = rand(5,5)
xs = rand(5, 5)
@test all(sum(softmax(xs), dims = 1) .≈ 1)
@test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1)
@test all(sum(softmax(xs; dims = 2), dims = 2) .≈ 1)
@test sum(softmax(vec(xs))) ≈ 1
@test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2)
@test log.(softmax(xs; dims = 2)) ≈ logsoftmax(xs; dims = 2)

xs = [-100_000, -100_000.]
xs = [-100_000.0, -100_000.0]
@test softmax(xs) ≈ [0.5, 0.5]
@test logsoftmax(xs) ≈ log.([0.5, 0.5])

xs = rand(5)
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs))
@test logsoftmax(xs) ≈ log.(softmax(xs))

xs = Float32[1, 2, 3000.]
xs = Float32[1, 2, 3000.0]
@test logsoftmax(xs) ≈ [-2999, -2998, 0]

xs = Float32[1 2 3; 1000 2000 3000]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]

@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
@test ∇logsoftmax(ones(Float32, size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test ∇softmax(ones(Float32, size(xs)), xs) ≈ zeros(Float32, size(xs))

# These values precalculated using PyTorch's nn.LogSoftmax
xs = [
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663;
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
]
ys = [
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273;
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428;
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428
0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697
]
@test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6)
@test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6)
@test isapprox(∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6)
@test isapprox(∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6)
end

@testset "mutating softmax" begin
xs = Float64[1 2 3; 5 6 7]

out = zeros(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = zeros(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
out = zeros(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
out = ones(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)

xs = [
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663;
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
]

out = zeros(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol=1e-6)
NNlib.logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol=1e-6)

out = zeros(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
out = zeros(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)

out = ones(Float64, size(xs))
NNlib.∇softmax!(out, xs)
@test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
out = ones(Float64, size(xs))
NNlib.∇logsoftmax!(out, xs)
@test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
map([
Float64[1 2 3; 5 6 7],
Float64[
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
],
]) do xs
out = similar(xs)
softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol = 1e-6)
logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol = 1e-6)

map([zeros, ones]) do fn
Δ = fn(Float64, size(xs))
∇softmax!(out, Δ, xs)
@test isapprox(out, ∇softmax(Δ, xs); rtol = 1e-6)
∇logsoftmax!(out, Δ, xs)
@test isapprox(out, ∇logsoftmax(Δ, xs); rtol = 1e-6)
end
end
end

@testset "logsumexp" begin
flogsoft(x; dims) = mean(x .- logsoftmax(x; dims=dims), dims=dims)
x = rand(3,4)
@test logsumexp(x) ≈ flogsoft(x, dims=:)
@test logsumexp(x; dims=1) ≈ flogsoft(x, dims=1)
@test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims=:), x)[1]
flogsoft(x; dims) = mean(x .- logsoftmax(x; dims = dims), dims = dims)

x = rand(3, 4)
@test logsumexp(x) ≈ flogsoft(x, dims = :)
@test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
@test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims = :), x)[1]
end