Compatibility issues of Zygote with results obtained from ForwardDiff #1189

Open
kishore-nori opened this issue Mar 23, 2022 · 5 comments
Labels
second order zygote over zygote, or otherwise

Comments

@kishore-nori

kishore-nori commented Mar 23, 2022

Hi everyone! First of all, thank you so much for the Flux and Zygote packages; they are very useful and a great experience to use :D.

I am running into an incompatibility when Zygote.gradient is applied to a result computed with ForwardDiff. The failing MWEs follow.

Specifically, I am trying to take the Zygote gradient of the Laplacian of a function, which has been computed using ForwardDiff as follows,

using Zygote: gradient as zygrad
using ForwardDiff
using LinearAlgebra
using StaticArrays # needed for SVector below

function laplacian_fd(f::Function, x)
  tr(ForwardDiff.jacobian(y->ForwardDiff.gradient(f,y), x))
end

g(a, x) = x[1]*sin(x[2]*norm(a)) # `a` plays the role of the parameters

h(a) = laplacian_fd(x->g(a,x), SVector(0.1,0.2))

h([0.1,0.2,0.3])

zygrad(h,[0.1,0.2,0.3]) # fails
julia> h([0.1,0.2,0.3])
-0.0010466865222526678

julia> zygrad(h,[0.1,0.2,0.3]) # fails
ERROR: Need an adjoint for constructor SVector{2, Float64}. Gradient is of type SparseArrays.SparseVector{Float64, Int64}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{SVector{2, Float64}, Nothing, false})(Δ::SparseArrays.SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/Zygote/3I4nT/src/lib/lib.jl:324
  [3] (::Zygote.var"#1788#back#229"{Zygote.Jnew{SVector{2, Float64}, Nothing, false}})(Δ::SparseArrays.SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/StaticArrays/NQjQM/src/SArray.jl:23 [inlined]
  [5] (::typeof((SVector{2, Float64})))(Δ::SparseArrays.SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/Zygote/3I4nT/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/StaticArrays/NQjQM/src/SVector.jl:20 [inlined]
  [7] Pullback
    @ ~/.julia/packages/StaticArrays/NQjQM/src/SVector.jl:19 [inlined]
  [8] Pullback
    @ ~/.julia/packages/StaticArrays/NQjQM/src/convert.jl:4 [inlined]
  [9] Pullback
    @ ./REPL[48]:1 [inlined]
 [10] (::typeof((h)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/3I4nT/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#57#58"{typeof((h))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/3I4nT/src/compiler/interface.jl:41
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/3I4nT/src/compiler/interface.jl:76
 [13] top-level scope
    @ REPL[50]:1

Although this is not my intended usage, the following fails as well:

zygrad(h, SVector(0.1,0.2,0.3)) # also returns the same error as above 

But the most surprising "error" is the following, which was going to be my backup if I can't get a fix for the above. Now I use Vector for both the variables (x), the inner AD part, and the parameters (a), the outer AD part.

julia> h(a) = laplacian_fd(x->g(a,x), [0.1,0.2])

julia> h([0.1,0.2,0.3])
-0.0010466865222526678

julia> zygrad(h,[0.1,0.2,0.3])
(nothing,) 
# :(

I would be very happy to learn of fixes or alternatives for the above two procedures, especially the first one. One brute-force fix I have is to use ForwardDiff for the outer gradient as well, but this is very slow, as expected, for a reasonably large number of parameters. I am using ForwardDiff on the inside because dim(x) in my case is not going to be large.

@mcabbott
Member

The problem is that Zygote.forwarddiff(f, x) only tracks the effect of x, not f or any parameters closed over by it, such as a here. And this is what's called by ForwardDiff.jacobian(f, x) within Zygote.
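A minimal sketch of this behaviour, adapted from the example in the Zygote.forwarddiff docstring (Float64 inputs and the names grads and grads_both are chosen here for illustration): the closed-over value b gets no gradient, whereas passing both values through the differentiated argument keeps them tracked.

```julia
using Zygote

# `b` is only closed over by the inner function, so Zygote.forwarddiff
# does not track it and its gradient is dropped.
grads = Zygote.gradient(2.0, 3.0) do a, b
    Zygote.forwarddiff(a) do a
        a * b
    end
end
# grads[1] is the derivative with respect to `a`; grads[2] is `nothing`.

# Passing both values through the differentiated argument keeps them tracked.
grads_both = Zygote.gradient(2.0, 3.0) do a, b
    Zygote.forwarddiff([a, b]) do v
        v[1] * v[2]
    end
end
```

In the MWE above, the parameters a sit in the same position as b in the first call, which would explain the (nothing,) result.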

This was added in #968, and should probably be removed, it's too much of a footgun.

What you should probably try is the other order, ForwardDiff over Zygote.

@kishore-nori
Author

Thank you for your comment explaining the reasons for the failure @mcabbott. Can functors help in this case to remove the shadow on the parameters, so that their effects can be tracked again?

Sure, I'll try them in the other order to check whether it works, but in my situation dim(x) << dim(a), so I was a little sceptical about the computational time with this order. My intuition may be wrong, though, since the gradient with respect to a actually has to differentiate through the code generated by the Laplacian.

ForwardDiff.gradient on the ForwardDiff-based Laplacian works very well, but may not be the best choice when dim(a) is large (~500).

Another thing I am trying is ReverseDiff on the Laplacian defined using ForwardDiff, which works but has some type conversion issues in my actual application, where x may not be a Vector or SVector but a struct.
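For completeness, a minimal sketch of this all-ForwardDiff fallback, assuming the Vector variant of h from the MWE. The names h_vec, a0 and grad_a are hypothetical and used only for illustration.

```julia
using ForwardDiff
using LinearAlgebra

# Laplacian as the trace of the Hessian, computed with nested ForwardDiff calls.
laplacian_fd(f, x) = tr(ForwardDiff.jacobian(y -> ForwardDiff.gradient(f, y), x))

g(a, x) = x[1] * sin(x[2] * norm(a))

# Hypothetical name: the Vector-based variant of `h` from the MWE.
h_vec(a) = laplacian_fd(x -> g(a, x), [0.1, 0.2])

a0 = [0.1, 0.2, 0.3]

# Outer gradient with respect to the parameters, also in forward mode.
grad_a = ForwardDiff.gradient(h_vec, a0)
```

Since the outer derivative is forward mode too, its cost grows with dim(a), which is the scaling concern mentioned above.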

@kishore-nori
Author

It would also be great if there were fixes for the following error in the above MWE. I have had this come up in other places as well, even when no gradient-over-Laplacian situation was involved.

ERROR: Need an adjoint for constructor SVector{2, Float64}. Gradient is of type SparseArrays.SparseVector{Float64, Int64}

@mcabbott
Member

> Can functors help

I don't think so, it's worse than that.

If ReverseDiff-over-ForwardDiff does work, then how it works might have clues to making Zygote-over-ForwardDiff work.

> but in my situation the dim(x) << dim(a) so I was little sceptical on the computational time if I choose this order.

500 is not crazy for ForwardDiff. Note that even when it does work well, the overload (which I think we should remove) turns Zygote-over-ForwardDiff (error) into Forward-over-Forward, so it won't help with algorithmic complexity concerns.

> ERROR: Need an adjoint for constructor SVector

This is a longstanding problem, and I think you will find sketches of how to fix it in old issues. If someone tidied them up & made a PR that would be great. But someone has to own it.

@kishore-nori
Author

Thank you for the insights! I'll try to dig deeper into ReverseDiff over ForwardDiff.

The Zygote.adjoint construction for StaticArrays tried out here (#570 (comment)) works. The error is now gone, and SVector and Vector behave the same way, both resulting in (nothing,). Just for the sake of completeness, we have the following:

@Zygote.adjoint (T::Type{<:SArray})(x::Number...) = T(x...), y->(nothing, y...)

julia> zygrad(h,[0.1,0.2,0.3])
(nothing,)

julia> zygrad(h,SVector(0.1,0.2,0.3))
(nothing,)
