added softmax!. reduced softmax's memory usage. #247
Diff: softmax implementation

@@ -1,6 +1,12 @@
-export softmax, softmax!, ∇softmax, ∇softmax!,
-    logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
-    logsumexp
+export softmax,
+    softmax!,
+    ∇softmax,
+    ∇softmax!,
+    logsoftmax,
+    logsoftmax!,
+    ∇logsoftmax,
+    ∇logsoftmax!,
+    logsumexp
 
 """
     softmax(x; dims=1)
@@ -26,49 +32,20 @@ julia> softmax([1, 2, 3])
 
 See also [`logsoftmax`](@ref).
 """
-function softmax(xs::AbstractArray; dims=1)
-    max_ = maximum(xs, dims=dims)
-    exp_ = exp.(xs .- max_)
-    exp_ ./ sum(exp_, dims=dims)
+softmax(x; dims = 1) = softmax!(similar(x), x; dims)
+softmax!(x; dims = 1) = softmax!(x, x; dims)
+function softmax!(out::T, x::T; dims = 1) where {T<:AbstractArray}
Review comment (on the `softmax!` signature above): we can relax the constraint of […]

Reply: I have updated the constraint for […] I added only one test case for integer input. Is this OK?

    @test softmax(Int[0, 0]) == [0.5, 0.5]
+    out .= exp.(x .- maximum(x; dims))
+    out ./= sum(out; dims)
 end
-
-function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
-    @inbounds for j = 1:size(xs, 2)
-        # First, store column-wise maximum in the last element of `out`
-        out[end, j] = xs[end, j]
-        @inbounds for i = 1:(size(xs, 1) - 1)
-            out[end, j] = max(out[end, j], xs[i, j])
-        end
-
-        # Subtract the column-wise maximums to normalize, take exp()
-        # out .= exp(xs .- out[end, :])
-        @inbounds for i = 1:size(out, 1)
-            out[i, j] = exp(xs[i, j] - out[end, j])
-        end
-
-        # Normalize by sum of the entire thing
-        # out ./= sum(out, 1)
-        s = T(0)
-        @inbounds for i = 1:size(out, 1)
-            s += out[i, j]
-        end
-        @inbounds for i = 1:size(out, 1)
-            out[i, j] /= s
-        end
-    end
-    out
-end
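For orientation, a minimal usage sketch of the API this hunk converges on (a sketch assuming the exports above; the reusable output buffer is where the memory reduction comes from):

```julia
using NNlib

x = randn(Float32, 10, 32)            # logits: 10 classes × 32 samples
out = similar(x)                      # preallocated buffer, reusable across calls
softmax!(out, x)                      # writes the probabilities into `out`
@assert all(sum(out; dims = 1) .≈ 1)  # each column sums to 1

softmax!(x)                           # one-argument form: overwrites `x` in place
```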
+∇softmax(Δ, x; dims = 1) = ∇softmax!(similar(Δ), Δ, x; dims)
+∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x; dims)
+function ∇softmax!(out::T, Δ::T, x::T; dims = 1) where {T<:AbstractArray}
+    softmax!(out, x; dims)
+    out .= out .* (Δ .- sum(Δ .* out; dims))
+end
-function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
-    sf = softmax(xs)
-    out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
-end
-function ∇softmax(Δ, xs; dims=1)
-    sf = softmax(xs, dims=dims)
-    sf .* (Δ .- sum(Δ .* sf, dims=dims))
-end
-∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
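As a sanity check on the fused expression in the new `∇softmax!` (this is the standard softmax vector-Jacobian product, restated here rather than taken from the PR): for a single vector, the softmax Jacobian is `Diagonal(s) - s*s'`, and the broadcast form agrees with it:

```julia
using LinearAlgebra

x, Δ = randn(5), randn(5)
s = exp.(x .- maximum(x)); s ./= sum(s)   # softmax of x
J = Diagonal(s) - s * s'                  # J[i,j] = s[i] * (δ_ij - s[j])
@assert J' * Δ ≈ s .* (Δ .- sum(Δ .* s))  # the expression ∇softmax! uses
```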
""" | ||
logsoftmax(x; dims=1) | ||
|
@@ -83,38 +60,22 @@ It is semantically equivalent to the following: | |
|
||
See also [`softmax`](@ref). | ||
""" | ||
-function logsoftmax(xs::AbstractArray; dims=1)
-    max_ = maximum(xs, dims=dims)
-    exp_ = exp.(xs .- max_)
-    log_ = log.(sum(exp_, dims=dims))
-    (xs .- max_) .- log_
-end
-
-function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
-    for j = 1:size(xs, 2)
-        @inbounds begin
-            xi_max = xs[1, j]
-            for i = 1:size(out, 1)
-                xi_max = max(xi_max, xs[i, j])
-            end
-            s = zero(eltype(out))
-            for i = 1:size(out, 1)
-                s += exp(xs[i, j] - xi_max)
-            end
-            for i = 1:size(out, 1)
-                out[i, j] = xs[i, j] - log(s) - xi_max
-            end
-        end
-    end
-    return out
+logsoftmax(x; dims = 1) = logsoftmax!(similar(x), x; dims)
+logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims)
+function logsoftmax!(out::T, x::T; dims = 1) where {T<:AbstractArray}
+    out .= x .- maximum(x; dims)
+    # out .= out .- log.(sum(exp.(out); dims = dims))  # WARN: this will decrease performance.
+    log_ = log.(sum(exp.(out); dims))
+    out .-= log_
 end
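The max shift used in both the old and new implementations is the usual numerical-stability identity, restated here for reference:

```math
\log \operatorname{softmax}(x)_i = x_i - \log\sum_j e^{x_j} = (x_i - m) - \log\sum_j e^{x_j - m}, \qquad m = \max_j x_j
```

Subtracting `m` first keeps `exp` from overflowing while leaving the result unchanged.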
-function ∇logsoftmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
-    out .= Δ .- sum(Δ, dims=1) .* softmax(xs, dims=1)
+∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax!(similar(Δ), Δ, x; dims)
+∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x; dims)
+function ∇logsoftmax!(out::T, Δ::T, x::T; dims = 1) where {T<:AbstractArray}
+    softmax!(out, x; dims)
+    out .= Δ .- sum(Δ, dims = dims) .* out
 end
 
-∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
-∇logsoftmax!(Δ, xs) = ∇logsoftmax!(Δ, Δ, xs)
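For reference (a standard identity, not something introduced by this PR): since `∂ logsoftmax(x)_i / ∂ x_j = δ_ij - s_j` with `s = softmax(x)`, the vector-Jacobian product is

```math
\left(J^{\top}\Delta\right)_j = \sum_i \Delta_i \left(\delta_{ij} - s_j\right) = \Delta_j - s_j \sum_i \Delta_i ,
```

which is exactly `Δ .- sum(Δ, dims = dims) .* softmax(x)` applied slice-wise.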
""" | ||
logsumexp(x; dims=:) | ||
|
@@ -124,8 +85,7 @@ way. | |
|
||
See also [`logsoftmax`](@ref). | ||
""" | ||
function logsumexp(xs::AbstractArray; dims=:) | ||
max_ = maximum(xs, dims=dims) | ||
log_ = log.(sum(exp.(xs .- max_), dims=dims)) | ||
return max_ .+ log_ | ||
function logsumexp(x::AbstractArray; dims = :) | ||
max_ = maximum(x; dims) | ||
max_ .+ log.(sum(exp.(x .- max_); dims)) | ||
end |
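A small sketch of why the shift matters here (illustrative values, assuming the definition above):

```julia
x = [1000.0, 1000.0]
log(sum(exp.(x)))  # Inf — exp(1000) overflows Float64
logsumexp(x)       # ≈ 1000.6931, i.e. 1000 + log(2), computed stably
```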
Diff: softmax tests

@@ -1,111 +1,75 @@
 using Zygote
+using Statistics: mean
 
@testset "softmax" begin | ||
xs = rand(5,5) | ||
xs = rand(5, 5) | ||
@test all(sum(softmax(xs), dims = 1) .≈ 1) | ||
@test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1) | ||
@test all(sum(softmax(xs; dims = 2), dims = 2) .≈ 1) | ||
@test sum(softmax(vec(xs))) ≈ 1 | ||
@test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2) | ||
@test log.(softmax(xs; dims = 2)) ≈ logsoftmax(xs; dims = 2) | ||
|
||
xs = [-100_000, -100_000.] | ||
xs = [-100_000.0, -100_000.0] | ||
@test softmax(xs) ≈ [0.5, 0.5] | ||
@test logsoftmax(xs) ≈ log.([0.5, 0.5]) | ||
|
||
xs = rand(5) | ||
@test softmax(xs) ≈ exp.(xs) ./ sum(exp.(xs)) | ||
@test logsoftmax(xs) ≈ log.(softmax(xs)) | ||
|
||
xs = Float32[1, 2, 3000.] | ||
xs = Float32[1, 2, 3000.0] | ||
@test logsoftmax(xs) ≈ [-2999, -2998, 0] | ||
|
||
xs = Float32[1 2 3; 1000 2000 3000] | ||
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.] | ||
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0] | ||
|
||
@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1] | ||
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs)) | ||
@test ∇logsoftmax(ones(Float32, size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1] | ||
@test ∇softmax(ones(Float32, size(xs)), xs) ≈ zeros(Float32, size(xs)) | ||
|
||
# These values precalculated using PyTorch's nn.LogSoftmax | ||
xs = [ | ||
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842; | ||
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663; | ||
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977 | ||
-0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842 | ||
0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663 | ||
-1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977 | ||
] | ||
ys = [ | ||
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273; | ||
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428; | ||
0.237703 -0.621474 0.448193 0.546047 0.564185 0.632273 | ||
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428 | ||
0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697 | ||
] | ||
@test isapprox(NNlib.∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6) | ||
@test isapprox(NNlib.∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6) | ||
@test isapprox(∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6) | ||
@test isapprox(∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6) | ||
end | ||
 
 @testset "mutating softmax" begin
-    xs = Float64[1 2 3; 5 6 7]
-
-    out = zeros(Float64, size(xs))
-    NNlib.softmax!(out, xs)
-    @test isapprox(out, softmax(xs); rtol=1e-6)
-    NNlib.logsoftmax!(out, xs)
-    @test isapprox(out, logsoftmax(xs); rtol=1e-6)
-
-    out = ones(Float64, size(xs))
-    NNlib.softmax!(out, xs)
-    @test isapprox(out, softmax(xs); rtol=1e-6)
-    NNlib.logsoftmax!(out, xs)
-    @test isapprox(out, logsoftmax(xs); rtol=1e-6)
-
-    out = zeros(Float64, size(xs))
-    NNlib.∇softmax!(out, xs)
-    @test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
-    out = zeros(Float64, size(xs))
-    NNlib.∇logsoftmax!(out, xs)
-    @test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
-
-    out = ones(Float64, size(xs))
-    NNlib.∇softmax!(out, xs)
-    @test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
-    out = ones(Float64, size(xs))
-    NNlib.∇logsoftmax!(out, xs)
-    @test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
-
-    xs = [
-        -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842;
-        0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663;
-        -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
-    ]
-
-    out = zeros(Float64, size(xs))
-    NNlib.softmax!(out, xs)
-    @test isapprox(out, softmax(xs); rtol=1e-6)
-    NNlib.logsoftmax!(out, xs)
-    @test isapprox(out, logsoftmax(xs); rtol=1e-6)
-
-    out = ones(Float64, size(xs))
-    NNlib.softmax!(out, xs)
-    @test isapprox(out, softmax(xs); rtol=1e-6)
-    NNlib.logsoftmax!(out, xs)
-    @test isapprox(out, logsoftmax(xs); rtol=1e-6)
-
-    out = zeros(Float64, size(xs))
-    NNlib.∇softmax!(out, xs)
-    @test isapprox(out, NNlib.∇softmax(zeros(size(xs)), xs); rtol=1e-6)
-    out = zeros(Float64, size(xs))
-    NNlib.∇logsoftmax!(out, xs)
-    @test isapprox(out, NNlib.∇logsoftmax(zeros(size(xs)), xs); rtol=1e-6)
-
-    out = ones(Float64, size(xs))
-    NNlib.∇softmax!(out, xs)
-    @test isapprox(out, NNlib.∇softmax(ones(size(xs)), xs); rtol=1e-6)
-    out = ones(Float64, size(xs))
-    NNlib.∇logsoftmax!(out, xs)
-    @test isapprox(out, NNlib.∇logsoftmax(ones(size(xs)), xs); rtol=1e-6)
+    map([
+        Float64[1 2 3; 5 6 7],
+        Float64[
+            -0.238639 0.748142 -0.283194 -0.525461 -1.5348 -0.797842
+            0.690384 0.211427 0.254794 -0.213572 -0.314174 -0.372663
+            -1.146370 -0.577988 0.718952 0.919720 -0.620773 0.929977
+        ],
+    ]) do xs
+        out = similar(xs)
+        softmax!(out, xs)
+        @test isapprox(out, softmax(xs); rtol = 1e-6)
+        logsoftmax!(out, xs)
+        @test isapprox(out, logsoftmax(xs); rtol = 1e-6)
+
+        map([zeros, ones]) do fn
+            Δ = fn(Float64, size(xs))
+            ∇softmax!(out, Δ, xs)
+            @test isapprox(out, ∇softmax(Δ, xs); rtol = 1e-6)
+            ∇logsoftmax!(out, Δ, xs)
+            @test isapprox(out, ∇logsoftmax(Δ, xs); rtol = 1e-6)
+        end
+    end
 end
 
 @testset "logsumexp" begin
-    flogsoft(x; dims) = mean(x .- logsoftmax(x; dims=dims), dims=dims)
-    x = rand(3,4)
-    @test logsumexp(x) ≈ flogsoft(x, dims=:)
-    @test logsumexp(x; dims=1) ≈ flogsoft(x, dims=1)
-    @test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims=:), x)[1]
+    flogsoft(x; dims) = mean(x .- logsoftmax(x; dims = dims), dims = dims)
+
+    x = rand(3, 4)
+    @test logsumexp(x) ≈ flogsoft(x, dims = :)
+    @test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
+    @test gradient(x -> logsumexp(x), x)[1] ≈ gradient(x -> flogsoft(x, dims = :), x)[1]
 end
Review comment: we still support Julia < 1.5, so you will have to explicitly pass the keyword value.
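For readers unfamiliar with the shorthand being discussed: `sum(out; dims)` relies on implicit keyword-argument names, a feature added in Julia 1.5, so on the older releases this package still supports the keyword value must be spelled out:

```julia
# Julia ≥ 1.5 only: a local variable named like the keyword can be passed bare.
out ./= sum(out; dims)
# Works on all supported versions: pass the keyword value explicitly.
out ./= sum(out; dims = dims)
```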
Review comment: also, this pattern will error on integer array inputs, which we still want to support, so we should promote to float types.
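A sketch of one way to handle this (hypothetical; the fix actually adopted in the PR may differ): dispatch integer arrays to a float-promoting method, so the output buffer can hold the result:

```julia
# Hypothetical promoting method: for an Int input, `similar(x)` yields an Int
# buffer, which `out .= exp.(x .- maximum(x; dims))` cannot be stored into.
softmax(x::AbstractArray{<:Integer}; dims = 1) = softmax(float.(x); dims = dims)

softmax(Int[0, 0])  # == [0.5, 0.5], matching the test mentioned above
```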
Reply: fixed.