Skip to content

Simpler defaults without FiniteDifferences special cases #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 24 additions & 70 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,8 @@ function value_and_gradient(ab::AbstractBackend, f, xs...)
return value, reshape.(adjoint.(jacs),size.(xs))
end
function value_and_jacobian(ab::AbstractBackend, f, xs...)
local value
primalcalled = false
if lowest(ab) isa AbstractFiniteDifference
value = primal_value(ab, nothing, f, xs)
primalcalled = true
end
jacs = jacobian(lowest(ab), (_xs...,) -> begin
v = f(_xs...)
if !primalcalled
value = primal_value(ab, v, f, xs)
primalcalled = true
end
return v
end, xs...)

value = f(xs...)
jacs = jacobian(lowest(ab), f, xs...)
return value, jacs
end
function value_and_hessian(ab::AbstractBackend, f, x)
Expand All @@ -89,71 +76,54 @@ function value_and_hessian(ab::AbstractBackend, f, x)
x = only(x)
end

local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primal_value(ab, nothing, f, (x,))
primalcalled = true
end
value = f(x)
hess = jacobian(second_lowest(ab), _x -> begin
v, g = value_and_gradient(lowest(ab), f, _x)
if !primalcalled
value = primal_value(ab, v, f, (x,))
primalcalled = true
end
g = gradient(lowest(ab), f, _x)
return g[1] # gradient returns a tuple
end, x)

return value, hess
end
function value_and_hessian(ab::HigherOrderBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)
end
local value
primalcalled = false

value = f(x)
hess = jacobian(second_lowest(ab), (_x,) -> begin
v, g = value_and_gradient(lowest(ab), f, _x)
if !primalcalled
value = primal_value(ab, v, f, (x,))
primalcalled = true
end
g = gradient(lowest(ab), f, _x)
return g[1] # gradient returns a tuple
end, x)

return value, hess
end
function value_gradient_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)
end
local value
primalcalled = false

value = f(x)
grads, hess = value_and_jacobian(second_lowest(ab), _x -> begin
v, g = value_and_gradient(lowest(ab), f, _x)
if !primalcalled
value = primal_value(second_lowest(ab), v, f, (x,))
primalcalled = true
end
g = gradient(lowest(ab), f, _x)
return g[1] # gradient returns a tuple
end, x)

return value, (grads,), hess
end
function value_gradient_and_hessian(ab::HigherOrderBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)
end
local value
primalcalled = false

value = f(x)
grads, hess = value_and_jacobian(second_lowest(ab), _x -> begin
v, g = value_and_gradient(lowest(ab), f, _x)
if !primalcalled
value = primal_value(second_lowest(ab), v, f, (x,))
primalcalled = true
end
g = gradient(lowest(ab), f, _x)
return g[1] # gradient returns a tuple
end, x)

return value, (grads,), hess
end

Expand All @@ -180,26 +150,16 @@ function value_and_pushforward_function(
f,
xs...,
)
return (ds) -> begin
n = length(xs)
value = f(xs...)
pf_function = pushforward_function(lowest(ab), f, xs...)

return ds -> begin
if !(ds isa Tuple)
ds = (ds,)
end
@assert length(ds) == length(xs)
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primal_value(ab, nothing, f, xs)
primalcalled = true
end
pf = pushforward_function(lowest(ab), (_xs...,) -> begin
vs = f(_xs...)
if !primalcalled
value = primal_value(lowest(ab), vs, f, xs)
primalcalled = true
end
return vs
end, xs...)(ds)

@assert length(ds) == n
pf = pf_function(ds)
return value, pf
end
end
Expand Down Expand Up @@ -476,12 +436,6 @@ macro primitive(expr)
return define_pushforward_function_and_friends(fdef) |> esc
elseif name == :value_and_pullback_function
return define_value_and_pullback_function_and_friends(fdef) |> esc
elseif name == :jacobian
return define_jacobian_and_friends(fdef) |> esc
elseif name == :primal_value
return define_primal_value(fdef) |> esc
elseif name == :pullback_function
return define_pullback_function_and_friends(fdef) |> esc
else
throw("Unsupported AD primitive.")
end
Expand Down