This package is currently compatible with ForwardDiff, but not with ReverseDiff, when differentiating through a call to nlsolve. Minimal example:
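using NLsolve, ForwardDiff, ReverseDiff

# residual r(y; x) whose root y depends on the parameters x
function residual!(r, y, x)
    r[1] = (y[1] + x[1])*(y[2]^3 - x[2]) + x[3]
    r[2] = sin(y[2]*exp(y[1]) - 1)*x[4]
end

# solve r(y; x) = 0 for y; the initial guess uses the element type of x
# (e.g. ForwardDiff dual numbers) so derivatives flow through the solve
function solve(x)
    TF = eltype(x)
    rwrap(r, y) = residual!(r, y, x[1:4])
    res = nlsolve(rwrap, TF[0.1; 1.2], autodiff=:forward)
    return res.zero
end

# function to differentiate, with the nonlinear solve in the middle
function program(x)
    z = 2.0*x
    w = z + x.^2
    y = solve(w)
    return y[1] .+ w*y[2]
end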
x = [1.0, 2.0, 3.0, 4.0, 5.0]
ForwardDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879    2.90746e-25   0.0
#   8.55572  14.3307    -3.0819     8.60063e-26   0.0
#  16.8073   12.2672     4.72726   -2.0063e-25    0.0
#  27.4165   20.0105    -9.87583   13.4769        0.0
#  40.3833   29.4746   -14.5467    -1.01959e-24  16.1723
ReverseDiff.jacobian(program, x)
# ERROR: UndefVarError: rT not defined
The reverse pass fails because the default constructor for SolverResults cannot determine the type parameter rT from its arguments, hence the UndefVarError. The fix is to give SolverResults an explicit inner constructor that computes its type parameters itself. For example:
mutable struct SolverResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    method::String
    initial_x::I
    zero::Z
    residual_norm::rT
    iterations::Int
    x_converged::Bool
    xtol::rT
    f_converged::Bool
    ftol::rT
    trace::SolverTrace
    f_calls::Int
    g_calls::Int

    # provide inner constructor (default inner constructor doesn't work for all cases)
    function SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                           xtol, f_converged, ftol, trace, f_calls, g_calls)
        # real type
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)), typeof(residual_norm), typeof(xtol), typeof(ftol))
        # real/complex type
        if promote_type(eltype(initial_x), eltype(zero)) <: Complex
            T = Complex{rT}
        else
            T = rT
        end
        # correct initial guess element type (if necessary)
        if !(eltype(initial_x) <: T)
            initial_x = T.(initial_x)
        end
        # correct zero element type (if necessary)
        if !(eltype(zero) <: T)
            zero = T.(zero)
        end
        # initial guess type
        I = typeof(initial_x)
        # zero type
        Z = typeof(zero)
        return new{rT,T,I,Z}(method, initial_x, zero, residual_norm, iterations,
                             x_converged, xtol, f_converged, ftol, trace, f_calls, g_calls)
    end
end
Then the ReverseDiff derivatives propagate as expected.
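To try the constructor pattern in isolation, here is a minimal standalone sketch. MiniResults is a hypothetical stand-in for SolverResults (not part of NLsolve) that keeps only the zero and residual_norm fields but the same rT/T/Z parameter layout. It shows the inner constructor picking rT by promotion, choosing a real or complex T, and converting the stored array only when its element type does not already match. It exercises only plain Float32, Float64, Int and complex inputs; behaviour with ReverseDiff's tracked numbers is not demonstrated here.

```julia
# Hypothetical type mirroring the rT/T/Z type-parameter layout of SolverResults,
# so the explicit inner-constructor pattern can be tried without patching NLsolve.
struct MiniResults{rT<:Real,T<:Union{rT,Complex{rT}},Z<:AbstractArray{T}}
    zero::Z
    residual_norm::rT

    function MiniResults(zero::AbstractArray, residual_norm::Real)
        # real type: promote the real part of the element type with the norm's type
        rT = promote_type(real(eltype(zero)), typeof(residual_norm))
        # keep complex solutions complex, otherwise use the real type
        T = eltype(zero) <: Complex ? Complex{rT} : rT
        # convert the stored array only when its element type does not match T
        if !(eltype(zero) <: T)
            zero = T.(zero)
        end
        # all type parameters are now known, so they can be passed to new explicitly
        return new{rT,T,typeof(zero)}(zero, convert(rT, residual_norm))
    end
end

# Float32 solution with a Float64 norm: everything is promoted to Float64
r1 = MiniResults(Float32[1.0, 2.0], 0.5)
# complex solution: rT is the real-part type, T is complex
r2 = MiniResults([1.0 + 2.0im, 3.0 - 1.0im], 0.1)
```

The design choice is the same as in the proposed SolverResults constructor: rather than asking Julia to solve for rT from the field types, the constructor computes every type parameter up front and hands them to new directly.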
Note that this issue involves passing derivatives through the nonlinear solve, rather than defining a custom pullback for the nonlinear solve (as discussed in #205).