Skip to content

Do gradient via mutating and unmutating cell #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,40 @@ Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)`
"""
function grad end

function grad(fdm, f, x::AbstractArray{T}) where T <: Number
function _grad(fdm, f, x::AbstractArray{T}) where T <: Number
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't technically matter, but maybe change / remove the type constraint on x, since we know it's going to be an Array by virtue of the fact it's called from grad

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be a a few things actually.
But yes, the constraints are not needed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to keep it because the T is useful

# x must be mutable, we will mutate it and then mutate it back.
dx = similar(x)
tmp = similar(x)
for k in eachindex(x)
dx[k] = fdm(zero(T)) do ϵ
tmp .= x
tmp[k] += ϵ
return f(tmp)
xk = x[k]
x[k] = xk + ϵ
ret = f(x)
x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
return ret
end
end
return (dx, )
end

grad(fdm, f, x::Array{<:Number}) = _grad(fdm, f, x)
# Fallback for when we don't know `x` will be mutable:
grad(fdm, f, x::AbstractArray{<:Number}) = _grad(fdm, f, similar(x).=x)

grad(fdm, f, x::Real) = (fdm(f, x), )
grad(fdm, f, x::Tuple) = (grad(fdm, (xs...)->f(xs), x...), )

function grad(fdm, f, d::Dict{K, V}) where {K, V}
dd = Dict{K, V}()
∇d = Dict{K, V}()
for (k, v) in d
dk = d[k]
function f′(x)
tmp = copy(d)
tmp[k] = x
return f(tmp)
d[k] = x
return f(d)
end
dd[k] = grad(fdm, f′, v)[1]
∇d[k] = grad(fdm, f′, v)[1]
d[k] = dk
end
return (dd, )
return (∇d, )
end

function grad(fdm, f, x)
Expand Down