added softmax!. reduced softmax's memory usage. #247
I think both changes are ok.
1. Added a dims parameter to softmax! and similar functions. Fixes #249.
2. Refactored the code.
I'm glad to help. I have updated all the functions and refactored the test cases by removing duplicate code. Performance has increased a little, and memory usage has dropped a lot.
Benchmark code:
using BenchmarkTools, NNlib
x = rand(1000,1000,16);
Δ = rand(1000,1000,16);
macro mb(expr)
:(median(@benchmark $expr samples=9 evals=1))
end
@mb softmax(x)
@mb logsoftmax(x)
@mb ∇softmax(Δ, x)
@mb ∇logsoftmax(Δ, x)
Benchmark results:
src/softmax.jl
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
exp_ ./ sum(exp_, dims=dims)
softmax(x; dims = 1) = softmax!(similar(x), x; dims)
We still support Julia < 1.5, so you will have to pass the keyword value explicitly (dims = dims).
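For context, the shorthand where a bare variable name after the semicolon is forwarded as a keyword of the same name was introduced in Julia 1.5 and is a syntax error on earlier releases. A minimal sketch of the portable spelling follows; `sumnorm!` and `sumnorm` are hypothetical stand-ins for illustration, not functions from this PR or from NNlib.

```julia
# Hypothetical helpers for illustration only; not NNlib code.
# In-place: divide each slice of x along `dims` by its sum, writing into `out`.
sumnorm!(out, x; dims = 1) = (out .= x ./ sum(x, dims = dims))

# Portable on every Julia 1.x release: the keyword is forwarded explicitly.
sumnorm(x; dims = 1) = sumnorm!(similar(x), x; dims = dims)

# Julia >= 1.5 only (a syntax error before that):
# sumnorm(x; dims = 1) = sumnorm!(similar(x), x; dims)
```

Both spellings behave identically on Julia 1.5 and later, so the explicit form costs nothing beyond a few characters.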
Also, this pattern will error on integer array inputs, which we still want to support, so we should promote to a float element type:
similar(x, float(eltype(x)))
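A small illustration of why the promotion matters: `similar(x)` keeps the element type of `x`, so an integer input yields an integer buffer that cannot hold the fractional results. The variable names below are only for this example.

```julia
x = [0, 0]                            # integer input, Vector{Int}

buf_int = similar(x)                  # also Vector{Int}
# buf_int .= exp.(x) ./ sum(exp.(x))  # would throw InexactError: 0.5 is not an Int

buf = similar(x, float(eltype(x)))    # Vector{Float64}; float(Int) is Float64
buf .= exp.(x) ./ sum(exp.(x))        # buf is now [0.5, 0.5]
```

Since `float` leaves floating-point types unchanged (for example `float(Float32)` is `Float32`), the promotion does not alter the output type for the usual float inputs.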
fixed.
src/softmax.jl
exp_ ./ sum(exp_, dims=dims)
softmax(x; dims = 1) = softmax!(similar(x), x; dims)
softmax!(x; dims = 1) = softmax!(x, x; dims)
function softmax!(out::T, x::T; dims = 1) where {T<:AbstractArray}
We can relax the constraint that out and x be of the same type.
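One way to relax the signature is to type the two arguments independently as AbstractArray, which lets a float output buffer pair with an integer or differently typed input. The sketch below uses the hypothetical names `my_softmax!` and `my_softmax` to avoid implying that this is the code that was merged; it combines the explicit keyword and the float promotion requested in the earlier comments.

```julia
# Hypothetical sketch; the code merged into NNlib may differ.
function my_softmax!(out::AbstractArray, x::AbstractArray; dims = 1)
    max_ = maximum(x, dims = dims)   # subtract the max for numerical stability
    out .= exp.(x .- max_)           # write the exponentials straight into `out`
    out ./= sum(out, dims = dims)    # normalize along `dims` without a new array
    return out
end

# Allocating wrapper: promote to a float element type so integer inputs work.
my_softmax(x::AbstractArray; dims = 1) =
    my_softmax!(similar(x, float(eltype(x))), x; dims = dims)

my_softmax(Int[0, 0])                # == [0.5, 0.5]
```

With this signature a call such as `my_softmax!(zeros(Float32, 2, 2), [1.0 2.0; 1.0 2.0])` is accepted, whereas the `out::T, x::T` form would reject it with a MethodError.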
I have updated the constraint for softmax! and logsoftmax!.
For ∇softmax, I guess Δ and out should be the same type. Right?
I added only one test case for integer input. Is this OK?
@test softmax(Int[0, 0]) == [0.5, 0.5]
Very nice. After addressing the comments, it will be convenient to add some tests with integer inputs.
I just noticed #248. I'm not sure how to handle it, but thought it was worth mentioning here.
Better to do #248 in a separate PR where we also bump the minor version.
You could add a test for #249 and we should be ready.
Done.
dims parameter?
softmax is updated too; this can reduce memory usage by 40%.
Note: the memory usage of softmax! is the same as that of the old softmax. Performance does not change.