Problems with variable indirect use #946

Open
RainerHeintzmann opened this issue Apr 17, 2021 · 12 comments

@RainerHeintzmann

We are trying to write an rrule for a custom array class, such that Zygote can differentiate through it, but are stuck due to an error about a missing adjoint for a constructor. This may well be a user error, but it could also be a problem of Zygote. Any help is appreciated!

using ChainRulesCore

struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N, Int}
    f::F
end

function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
    a.f(idx)
end
Base.size(e::Example) = e.sz

function ChainRulesCore.rrule(::typeof(Example{Float64,2,F}),
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    function IFA_pullback(ΔΩ)
        @show outer = ΔΩ  
        @show inner = Example{Float64,2,typeof(gen)}(sz, gen)
        ∂gen = outer .* inner # wrap in @thunk()
        @show  ∂gen
        return (NO_FIELDS, NO_FIELDS, ∂gen) # why do we need four here?
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

This code generally works fine for using the array, but the point is that we need to differentiate with respect to a variable used in the innermost function. The following code uses these definitions and triggers the error:

c(a) = begin
    g(idx)= idx[1]*idx[2]*a
    sum(Example{Float64,2,typeof(g)}((3,3),g))
end

using Zygote
gradient(c, 2)

The error looks like this:

julia> include("Scratch_05_GradientTest_Chain_.jl")
outer = ΔΩ = 3×3 Fill{Int64}: entries equal to 1
inner = Example{Float64, 2, typeof(gen)}(sz, gen) = [2 4 6; 4 8 12; 6 12 18]
∂gen = [2.0 4.0 6.0; 4.0 8.0 12.0; 6.0 12.0 18.0]
ERROR: Need an adjoint for constructor var"#g#6"{Int64}. Gradient is of type Matrix{Float64}
Stacktrace:
 [1] error(s::String)  
   @ Base .\error.jl:33
 [2] (::Zygote.Jnew{var"#g#6"{Int64}, Nothing, false})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\lib\lib.jl:314
 [3] (::Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] Pullback
   @ ~\Documents\Programming\Julia\Development\Scratch_05_GradientTest_Chain_.jl:29 [inlined]
 [5] (::Zygote.Pullback{Tuple{typeof(c), Int64}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1566#back#125"{typeof(identity)}, Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}}, Zygote.ZBack{var"#IFA_pullback#5"{Tuple{Int64, Int64}, var"#g#6"{Int64}}}, Zygote.var"#2646#back#601"{Zygote.var"#597#599"{Example{Float64, 2, var"#g#6"{Int64}}}}}})(Δ::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{Zygote.Pullback{Tuple{typeof(c), Int64}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1566#back#125"{typeof(identity)}, Zygote.var"#1720#back#194"{Zygote.Jnew{var"#g#6"{Int64}, Nothing, false}}, Zygote.ZBack{var"#IFA_pullback#5"{Tuple{Int64, Int64}, var"#g#6"{Int64}}}, Zygote.var"#2646#back#601"{Zygote.var"#597#599"{Example{Float64, 2, var"#g#6"{Int64}}}}}}})(Δ::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:41
 [7] gradient(f::Function, args::Int64)
   @ Zygote ~\.julia\packages\Zygote\RxTZu\src\compiler\interface.jl:59
 [8] top-level scope
   @ ~\Documents\Programming\Julia\Development\Scratch_05_GradientTest_Chain_.jl:34

Something similar happens if you place the variable a directly after sum( in function c.
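To separate the forward pass from the AD problem, here is a minimal Base-only sketch (no Zygote or ChainRulesCore) of the same lazy array, evaluated for a = 2. The helper name `c_forward` is hypothetical and mirrors `c` above; it only computes the primal value. Its entries should match the `inner` matrix in the output above, which sums to 72.

```julia
# Lazy array whose entries are computed on demand by a generator function f
struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N,Int}
    f::F
end

# Cartesian indexing: pass the index tuple on to the generator
Base.getindex(a::Example{T,N}, idx::Vararg{Int,N}) where {T,N} = a.f(idx)
Base.size(e::Example) = e.sz

# Hypothetical helper mirroring c(a) from the issue (forward pass only)
function c_forward(a)
    g(idx) = idx[1] * idx[2] * a          # closure capturing a
    return sum(Example{Float64,2,typeof(g)}((3, 3), g))
end

c_forward(2)   # entries [2 4 6; 4 8 12; 6 12 18] sum to 72
```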

@DhairyaLGandhi
Member

For Zygote, it would be better to use Composite. I would try to use @adjoint here.

@RainerHeintzmann
Author

RainerHeintzmann commented Apr 17, 2021

Thanks for the great hint. Can you point me to an example or documentation on Composite and canonicalize()? The documentation of these functions did not help me figure out how to define them correctly. Presumably one needs to define a struct Composite and then call canonicalize with an instance of such a struct?

@mzgubic
Collaborator

mzgubic commented Apr 19, 2021

You can have a look at the ChainRules documentation, and for examples see the ChainRules package.

@MikeInnes
Member

We've discussed this over email, but I thought the short version might as well be recorded here. The reason for the error is that a matrix, ∂gen, is used as the gradient of a closure, g. When it comes to unpacking ∂gen to get the gradient of a, Zygote doesn't know what to do. In this example you'd want to use Zygote._pullback inside the rule to get the gradient of g.
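Why a matrix cannot serve as the gradient of g: in Julia a closure is an ordinary struct whose fields are the captured variables, so its structural gradient has to be keyed by those fields. A small Base-only illustration; `make_g` is a hypothetical helper that builds the same kind of closure as `g` inside `c`.

```julia
# Build a closure over `a`, like g inside c(a)
function make_g(a)
    g(idx) = idx[1] * idx[2] * a
    return g
end

g = make_g(2.0)

fieldnames(typeof(g))   # (:a,): the captured variable is a field of the closure
g.a                     # 2.0: it can be read like any struct field
g((2, 3))               # 12.0
```

A gradient for `g` therefore has the shape of a named tuple over these fields, e.g. `(a = ...,)`, not a matrix the size of the array.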

@devmotion
Collaborator

Another problem here seems to be that you define the rrule for ::typeof(Example{...}) instead of ::Type{Example{...}} or ::Type{<:Example{...}}: https://juliadiff.org/ChainRulesCore.jl/previews/PR331/writing_good_rules.html#Use-Type{T},-not-typeof(T),-to-define-rules-for-constructors (this is not in the official documentation yet, only in a PR preview).
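The difference matters for dispatch: `typeof(Example{Float64,2,F})` is just `DataType`, so an argument annotated with it accepts any `DataType` (for example `Int`), whereas `Type{...}` accepts only the intended constructor. A Base-only check, using `sin` as an arbitrary stand-in generator chosen for illustration:

```julia
struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N,Int}
    f::F
end

E = Example{Float64,2,typeof(sin)}

typeof(E)                # DataType, the same as typeof(Int)
Int isa typeof(E)        # true: `::typeof(E)` would also accept Int
E isa Type{<:Example}    # true
Int isa Type{<:Example}  # false: `::Type{<:Example}` only accepts Example types
```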

@RainerHeintzmann
Author

Thanks for picking this up! Yet the error message remains the same...

@RainerHeintzmann
Author

RainerHeintzmann commented May 4, 2021

This is my attempt at using the suggestion of @MikeInnes:

function ChainRulesCore.rrule(::Type{Example{Float64,2,F}}, 
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    val_grad(x) = Zygote._pullback(gen, x)[2](1.0) 
    gradgen(x) = val_grad(x)[1][:a] 
    function IFA_pullback(ΔΩ)    
        inner = Example{Float64,2,typeof(gradgen)}(sz, gradgen) 
        ∂gen = ΔΩ .*  inner 
        @show ∂gen
        return (NO_FIELDS,NO_FIELDS,∂gen) 
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

As the @show ∂gen output shows, this code gets pretty far, yet Zygote still gets stuck:

julia> gradient(c, 2.0)
∂gen = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]
ERROR: Need an adjoint for constructor var"#g#14"{Float64}. Gradient is of type Matrix{Float64}
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] (::Zygote.Jnew{var"#g#14"{Float64}, Nothing, false})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:314
 [3] (::Zygote.var"#1723#back#196"{Zygote.Jnew{var"#g#14"{Float64}, Nothing, false}})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] Pullback
   @ ~\Documents\Programming\Julia\Development\TestingZygote.jl:63 [inlined]
 [5] (::typeof((c)))(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof((c))})(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:41
 [7] gradient(f::Function, args::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:59
 [8] top-level scope

@MikeInnes
Member

Again though, you're giving a matrix ∂gen as the gradient for the closure g, so it's the same issue as before. The gradient (for this specific g) should be a named tuple of the form (a = da::Real,); that's something the pullback rule you've defined has to get right.

I suspect the right gradient here would be (a = sum(∂gen),), but that would only work for this specific closure, since others might have more than one capture or call it something other than a. So :a shouldn't appear in the code.

Instead, you want to do something like broadcast the pullback of g over ΔΩ. That gets you a matrix of named tuples, which you can sum with Zygote.accum_sum (which is like sum but supports named tuples).
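The suggested reduction can be sketched without Zygote by writing the pullback of the closure by hand. This assumes the generator from the later example, g(idx) = idx[1] + idx[2]*a*a, at a = 2.0 with a seed of ones. `closure_pullback` and `add_nt` are hypothetical stand-ins for what `Zygote._pullback` and `Zygote.accum_sum` would provide.

```julia
# Hand-written pullback of g(idx) = idx[1] + idx[2]*a*a with respect to the
# captured variable a: returns the closure's gradient as a NamedTuple
closure_pullback(a, idx, Δ) = (a = Δ * 2 * a * idx[2],)

# Field-wise sum of NamedTuples with identical keys (stand-in for Zygote.accum_sum)
add_nt(x::NamedTuple{K}, y::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(+, values(x), values(y)))

a = 2.0
ΔΩ = ones(3, 3)                           # incoming seed, one entry per array element
idxs = Tuple.(CartesianIndices((3, 3)))   # 3×3 matrix of (i, j) index tuples

per_elem = closure_pullback.(a, idxs, ΔΩ) # 3×3 matrix of NamedTuples
∂g = reduce(add_nt, per_elem)             # (a = 72.0,)
```

The per-element `a` values (4, 8, 12 in every row) match the ∂gen matrix printed above; summing them gives a single named tuple, which is the shape Zygote expects for the gradient of g.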

@RainerHeintzmann
Author

Thanks, @MikeInnes, for this hint. It took me ages to understand that it is not the returned tuple itself that needs to be a named tuple, but its third element. As far as I can see, there is no need to involve sum or Zygote.accum_sum, but it's useful to know that they exist. Here is an implementation that should hopefully also work for slightly more general cases:

using ChainRulesCore
using Zygote

struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N, Int}
    f::F
end

function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
    a.f(idx)
end
Base.size(e::Example) = e.sz

function ChainRulesCore.rrule(::Type{Example{Float64,2,F}}, 
    sz::NTuple{N, Int},
    gen::F,
    ) where {F,N}
    val_grad(idx) = Zygote._pullback(gen, idx)[2](1.0) # 1.0 is only the seed
    mySymbols = keys(val_grad(sz)[1])
    gradgen(idx) = val_grad(idx)[1] 
    function IFA_pullback(ΔΩ)   
        Fcts = ((idx)-> val_grad(idx)[1][aSymbol] for aSymbol in mySymbols)
        TupleVals = (ΔΩ .* Example{Float64,2,typeof(Fun)}(sz, Fun) for Fun in Fcts)
        ∂gen = NamedTuple{mySymbols}(TupleVals)
        return (NO_FIELDS, NO_FIELDS, ∂gen) 
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

c(a) = begin
    g(idx)= idx[1] + idx[2] *a*a
    myarr = Example{Float64,2,typeof(g)}((3,3),g)  # 3,3 refers to size
    sum(myarr)  
end

The output looks like this:

julia> gradient(c, 2.0)
([4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0],)

Phew. That took longer than planned ;-)
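For reference, the derivative of c with respect to a can be checked independently of any AD package. With g(idx) = idx[1] + idx[2]*a*a on a 3×3 grid, dc/da = 2a · Σ_{i,j} j = 36a, i.e. 72 at a = 2. That equals the sum of the entries of the matrix returned above. A minimal Base-only check; `c_plain` is a hypothetical stand-in for `c` that sums the same entries directly.

```julia
# Same entries as Example((3,3), g) with g(idx) = idx[1] + idx[2]*a*a
c_plain(a) = sum(i + j * a * a for i in 1:3, j in 1:3)

# d/da of sum_{i,j} (i + j*a^2) = 2a * sum_{i,j} j = 2a * 18 = 36a
dc_analytic(a) = 36a

# Central finite difference (exact for a quadratic, up to rounding)
fd(f, x; h = 1e-4) = (f(x + h) - f(x - h)) / (2h)

dc_analytic(2.0)                                  # 72.0
fd(c_plain, 2.0)                                  # ≈ 72.0
sum([4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0])   # 72.0
```

So for a scalar input the expected gradient is the scalar 72.0; the returned matrix is its per-element breakdown, and reducing it with sum gives that scalar.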

@RainerHeintzmann
Author

Does anyone know if Zygote._pullback(gen, idx)[2](1.0) can be avoided or replaced with a function in ChainRulesCore? It would be nice to avoid the dependence of our package on Zygote.

@devmotion
Collaborator

Only if you know the differential explicitly or if there exists an rrule that you can call. Currently, ChainRules does not support calling back into the AD system (such as Zygote): JuliaDiff/ChainRulesCore.jl#68

@mzgubic
Collaborator

mzgubic commented May 17, 2021

The plan is to allow this by JuliaCon, so watch the issue @devmotion posted in case you find this useful.
