Allow ReverseDiff Propagation #281

Open
taylormcd opened this issue Nov 1, 2022 · 0 comments


This package currently supports differentiating through `nlsolve` with ForwardDiff, but not with ReverseDiff.

```julia
using NLsolve, ForwardDiff, ReverseDiff

# Residual of a 2-equation system in the unknowns y, parameterized by x[1:4]
function residual!(r, y, x)
    r[1] = (y[1] + x[1])*(y[2]^3-x[2])+x[3]
    r[2] = sin(y[2]*exp(y[1])-1)*x[4]
end

# Solve residual!(r, y, x) = 0 for y; the initial guess uses eltype(x) so that
# dual/tracked numbers are carried through the solve
function solve(x)
    TF = eltype(x)
    rwrap(r, y) = residual!(r, y, x[1:4])
    res = nlsolve(rwrap, TF[0.1; 1.2], autodiff=:forward)
    return res.zero
end

function program(x)
    z = 2.0*x
    w = z + x.^2
    y = solve(w)
    return y[1] .+ w*y[2]
end

x = [1.0, 2.0, 3.0, 4.0, 5.0]

ForwardDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879   2.90746e-25   0.0
#   8.55572  14.3307    -3.0819    8.60063e-26   0.0
#  16.8073   12.2672     4.72726  -2.0063e-25    0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467   -1.01959e-24  16.1723

ReverseDiff.jacobian(program, x)
# ERROR: UndefVarError: rT not defined
```

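The reverse pass fails because the default constructor for `SolverResults` cannot infer the type parameter `rT` from the arguments it receives under ReverseDiff. One fix is to give `SolverResults` an explicit inner constructor that computes the type parameters itself, promoting the fields to a common real type. For example: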

```julia
mutable struct SolverResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    method::String
    initial_x::I
    zero::Z
    residual_norm::rT
    iterations::Int
    x_converged::Bool
    xtol::rT
    f_converged::Bool
    ftol::rT
    trace::SolverTrace
    f_calls::Int
    g_calls::Int
    # explicit inner constructor (the automatically generated constructor
    # cannot always infer the type parameters, e.g. under ReverseDiff)
    function SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
        xtol, f_converged, ftol, trace, f_calls, g_calls)

        # common real type of the array element types and the real-valued scalars
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)), typeof(residual_norm), typeof(xtol), typeof(ftol))

        # element type: complex if either array is complex, otherwise real
        if promote_type(eltype(initial_x), eltype(zero)) <: Complex
            T = Complex{rT}
        else
            T = rT
        end

        # convert the initial guess unless its element type is exactly T
        # (AbstractArray{T} is invariant in T, so a subtype check is not enough)
        if eltype(initial_x) !== T
            initial_x = T.(initial_x)
        end

        # convert the zero (solution) unless its element type is exactly T
        if eltype(zero) !== T
            zero = T.(zero)
        end

        # array types of the initial guess and the zero
        I = typeof(initial_x)
        Z = typeof(zero)

        # `new` converts residual_norm, xtol, and ftol to rT
        return new{rT,T,I,Z}(method, initial_x, zero, residual_norm, iterations,
            x_converged, xtol, f_converged, ftol, trace, f_calls, g_calls)
    end
end
```

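With this constructor in `SolverResults`, ReverseDiff derivatives propagate through `nlsolve` as expected. The Jacobian matches the ForwardDiff result above, apart from fourth-column entries of order 1e-24 or smaller, which are numerically zero in both cases.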

```julia
# with modified implementation
ReverseDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879  -4.91066e-28   0.0
#   8.55572  14.3307    -3.0819   -2.04315e-27   0.0
#  16.8073   12.2672     4.72726  -2.00371e-27   0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467    0.0          16.1723
```
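The promotion logic of the proposed constructor can be checked in isolation with a minimal sketch. It uses a hypothetical `ToyResults` type that copies the type-parameter layout of `SolverResults` but keeps only three fields, so it runs in plain Julia without NLsolve or ReverseDiff. It exercises ordinary `Float32`/`Float64`/`Int`/complex inputs rather than ReverseDiff tracked numbers, and it compares element types with `===` because `AbstractArray{T}` is invariant in `T`.

```julia
# Hypothetical `ToyResults`: same type-parameter layout as `SolverResults`,
# reduced to three fields so it runs without NLsolve
struct ToyResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    initial_x::I
    zero::Z
    residual_norm::rT
    function ToyResults(initial_x, zero, residual_norm)
        # common real type of both array element types and the scalar norm
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)), typeof(residual_norm))
        # complex element type if either array is complex, otherwise real
        T = promote_type(eltype(initial_x), eltype(zero)) <: Complex ? Complex{rT} : rT
        # convert arrays whose element type is not exactly T
        eltype(initial_x) === T || (initial_x = T.(initial_x))
        eltype(zero) === T || (zero = T.(zero))
        # `new` converts residual_norm to rT
        return new{rT,T,typeof(initial_x),typeof(zero)}(initial_x, zero, residual_norm)
    end
end

# Mixed real inputs: Float32 guess, Float64 solution, Int norm
res = ToyResults(Float32[0.1, 1.2], [0.5, 1.0], 0)
println(typeof(res))

# A complex solution makes T complex while rT stays real
cres = ToyResults([0.1, 1.2], ComplexF64[0.5, 1.0im], 1.0f0)
println(typeof(cres))
```

In the first call every parameter is promoted to `Float64`; in the second, `T` becomes `ComplexF64` while `rT` stays `Float64`, which is the case the `Union{rT,Complex{rT}}` bound is meant to allow.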

Note that this issue concerns propagating derivatives through the iterations of the nonlinear solve itself, rather than defining a custom pullback for it (as discussed in #205).
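For contrast, a custom pullback would usually avoid differentiating through the solver iterations and instead apply the implicit function theorem at the converged solution: if r(y*(x), x) = 0 and ∂r/∂y is invertible, then dy*/dx = -(∂r/∂y)⁻¹ ∂r/∂x. The sketch below illustrates that idea on a hypothetical scalar residual with a hand-written Newton solver (not NLsolve, and not necessarily the approach discussed in #205), and checks it against a finite difference taken through the solver.

```julia
# Hypothetical scalar residual resid(y, x) = y^3 + y - x; since d(resid)/dy = 3y^2 + 1 > 0,
# each x has a unique root y*(x)
resid(y, x) = y^3 + y - x
dresid_dy(y, x) = 3y^2 + 1
dresid_dx(y, x) = -1.0

# Plain Newton iteration, standing in for a nonlinear solver such as nlsolve
function newton_solve(x; y0 = 0.0, tol = 1e-12, maxiter = 50)
    y = y0
    for _ in 1:maxiter
        step = resid(y, x) / dresid_dy(y, x)
        y -= step
        abs(step) < tol && break
    end
    return y
end

# Implicit function theorem: dy*/dx = -(dr/dy)^(-1) * dr/dx, evaluated once at the
# converged root; no derivative information flows through the Newton iterations
function implicit_derivative(x)
    y = newton_solve(x)
    return -dresid_dx(y, x) / dresid_dy(y, x)
end

# Central finite difference taken through the solver, for comparison
fd_derivative(x; h = 1e-6) = (newton_solve(x + h) - newton_solve(x - h)) / (2h)

println(implicit_derivative(2.0))
println(fd_derivative(2.0))
```

For x = 2 the root is y* = 1, so the implicit derivative is 1/(3·1² + 1) = 0.25. In the vector case the division becomes a linear solve with the Jacobian ∂r/∂y (or its transpose, in a reverse-mode pullback).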
