f/rrules should support receiving ZeroTangent() #442

Open
mzgubic opened this issue Jun 11, 2021 · 7 comments
Labels
design Requires some design before changes are made needs-careful-thought A reminder that this thing is not obvious and care must be taken when working on it

@mzgubic
Member

mzgubic commented Jun 11, 2021

Similarly to #408, ZeroTangent() is a valid input to the pullback, and we need to make sure it is supported.

Using JuliaDiff/ChainRulesTestUtils.jl#176, there are at least three kinds of errors:

  1. Pullbacks are written such that they do not support taking in ZeroTangent(), e.g.
    MethodError: no method matching (::ChainRules.var"#transpose_pullback#1894")(::ZeroTangent). These just need to be fixed in ChainRules.jl
  2. Places where we (I think?) have to project the ZeroTangent():
    TypeError: in Hermitian, in S, expected S<:(AbstractArray{var"#s832", 2} where var"#s832"<:T), got Type{ZeroTangent}
    and
    MethodError: Cannot `convert` an object of type ZeroTangent to an object of type Matrix{T} where T
  3. Errors which could be solved by projecting the ZeroTangent(), e.g.
    MethodError: no method matching getindex(::ZeroTangent, ::Int64). The question is whether we actually want to project to an array, since that would allocate quite a bit. Alternatively, we could define Base.getindex(::ZeroTangent, args...) = ZeroTangent(). There might be quite a few of these functions to define, but it would be much faster.
    Some examples are:
  • MethodError: no method matching Complex(::ZeroTangent)
  • MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::ZeroTangent; dims=Colon())
  • MethodError: no method matching tr(::ZeroTangent)
  • MethodError: no method matching mul!(::ZeroTangent, ::Matrix{Float64}, ::ZeroTangent, ::Bool, ::Bool)
  • MethodError: no method matching trsyl!(::Char, ::Char, ::Matrix{ComplexF64}, ::Matrix{ComplexF64}, ::ZeroTangent)
  • MethodError: no method matching size(::ZeroTangent, ::Int64)
  • MethodError: no method matching LowerTriangular(::ZeroTangent)
    and the list goes on
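
The non-allocating alternative mentioned in item 3 can be sketched as follows. This is a minimal illustration, not ChainRulesCore's implementation: `SketchZero` is a hypothetical stand-in for ZeroTangent so that the example runs without the package, and the four methods are just a sample of the many that would be needed.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore.ZeroTangent (an assumption of this
# sketch, not the package's type), so the example has no dependencies.
struct SketchZero end

# Each of these operations is linear, so it maps zero to zero.
# Returning the zero itself avoids allocating a dense array of zeros.
Base.getindex(z::SketchZero, args...) = z
Base.transpose(z::SketchZero) = z
LinearAlgebra.tr(z::SketchZero) = z
LinearAlgebra.LowerTriangular(z::SketchZero) = z

z = SketchZero()
z[1]                 # indexing stays zero, no array is created
tr(z)                # likewise for the trace
LowerTriangular(z)   # and for the wrapper constructor
```

Nothing here projects onto an array type, which is where the speed advantage over projection would come from.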

This is just a quick (incomplete) dump of observations and first thoughts. I may have missed kinds of errors, or said things which are untrue.

Altogether:

Test Summary: |  Pass  Error  Broken  Total
ChainRules    | 24330    429       4  24763
@mzgubic mzgubic added this to the v1 milestone Jun 11, 2021
@mzgubic mzgubic added design Requires some design before changes are made needs-careful-thought A reminder that this thing is not obvious and care must be taken when working on it labels Jun 11, 2021
@oxinabox
Member

I think most of the ones under 3 should be solved by implementing the methods for ZeroTangent,
without considering projection.
While there are many, they are finite in number,
and the answers are mostly obvious: in general the result is ZeroTangent(), because linear operators map zero to zero.
In some cases it isn't, but we would be able to get through a lot of them pretty quickly.
And it is the same set as for #408.

@mcabbott
Member

Should every rrule have a method for this, or should it be handled at a higher level, by not calling the rule at all? The former seems like a lot of boilerplate. Maybe the answer depends on this:

> In some cases it isn't

What cases are there where a zero tangent should become something nonzero?

@oxinabox
Member

The case we have to worry about is functions with multiple inputs, some of which are zero, and some of which are not.

Saying that, I realize it makes no sense: pullbacks always have one input, because Julia functions all have a single output (it is just that sometimes that output is an iterator).
So yes, I think this can be handled at a higher level by not calling the pullback at all.

@mcabbott
Member

Yes, I guess if your function returns a tuple, then it may have to worry about getting back (Zero(), something). Maybe this one would work?

julia> y, b = pullback(findmax, rand(3));

julia> b((1,2))
([0.0, 1.0, 0.0],)

julia> b((pi,3))
([0.0, 3.141592653589793, 0.0],)

julia> b((pi,"nothing"))
ERROR: MethodError: no method matching +(::Int64, ::String)
  [1] accum(x::Int64, y::String)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17
  [2] macro expansion
    @ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27 [inlined]
  [3] accum(x::NamedTuple{(:first, :second), Tuple{Int64, Irrational{:π}}}, y::NamedTuple{(:first, :second), Tuple{String, Float64}})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:27
  [4] accum
    @ ~/.julia/packages/Zygote/0da6K/src/lib/lib.jl:17 [inlined]
  [5] (::typeof(∂(#250)))(Δ::Tuple{Irrational{:π}, String})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./reduce.jl:95 [inlined]
  [7] (::typeof(∂(Base.MappingRF{Base.var"#250#251"{typeof(identity)}, Base.BottomRF{typeof(Base._rf_findmax)}}(Base.var"#250#251"{typeof(identity)}(identity), Base.BottomRF{typeof(Base._rf_findmax)}(Base._rf_findmax)))))(Δ::Tuple{Irrational{:π}, String})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0

@oxinabox
Member

oxinabox commented Jun 30, 2021

(Zero(), something) should be Tangent{Tuple{Int, Float64}}(ZeroTangent(), something),
since Tuple is not a valid tangent type: it doesn't support zero or +.
Though a lot of our methods do let you pass in any iterator, when really you should only be allowed to pass in a Tangent{Tuple}
(or some other iterator that overloads zero, +, etc.).
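
To make the point about Tuple concrete, here is a small sketch in plain Julia. `SketchZero` and `SketchTangent` are hypothetical stand-ins invented for this example (they are not ChainRulesCore's ZeroTangent or Tangent); they only show what "supports zero and +" means for a tuple-shaped tangent with a zero in one slot.

```julia
# Hypothetical stand-ins, assumed for illustration only.
struct SketchZero end
Base.:+(z::SketchZero, ::SketchZero) = z
Base.:+(::SketchZero, x) = x
Base.:+(x, ::SketchZero) = x

# A tuple wrapper that, unlike a bare Tuple, defines + and zero.
struct SketchTangent{T<:Tuple}
    parts::T
end
Base.:+(a::SketchTangent, b::SketchTangent) = SketchTangent(map(+, a.parts, b.parts))
Base.zero(t::SketchTangent) = SketchTangent(map(_ -> SketchZero(), t.parts))

# Returns true if calling `f` throws a MethodError.
function throws_method_error(f)
    try
        f()
        return false
    catch err
        return err isa MethodError
    end
end

# Bare tuples define neither + nor zero in Base Julia.
throws_method_error(() -> (1, 2.0) + (3, 4.0))
throws_method_error(() -> zero((1, 2.0)))

# The wrapper accumulates fine, even with a zero in the first slot.
total = SketchTangent((SketchZero(), 1.5)) + SketchTangent((2, 0.5))
```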

@mzgubic
Member Author

mzgubic commented Jul 1, 2021

What do you mean by "handled at a higher level"? Like handled automatically by the AD system, or making rrule a macro which adds a line for treating ZeroTangent automatically? Or is there another way?

@oxinabox
Member

oxinabox commented Jul 1, 2021

Handled in the AD system before calling the pullback,
like this line in Zygote: https://github.com/FluxML/Zygote.jl/blob/1082ebd3aced63b99c4b6c2956a122ce6a37f97d/src/compiler/chainrules.jl#L94
And this is where we would change Nabla: https://github.com/invenia/Nabla.jl/pull/189/files#r662148541
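
A rough sketch of what "handled in the AD system" could look like: the AD wraps each pullback so that a zero cotangent short-circuits and the rule author's pullback is never called with it. This is not the code at either link; `SketchZero`, `skip_zero` and `square_pullback` are hypothetical names, and the sketch ignores the tangent for the function itself that a real rrule pullback also returns.

```julia
# Hypothetical stand-in for ChainRulesCore.ZeroTangent (assumption of this sketch).
struct SketchZero end

# Wrap a pullback so that a zero cotangent never reaches it.
# `n_inputs` is how many input cotangents the pullback would return.
function skip_zero(pullback, n_inputs::Int)
    return function (dy)
        # Zero in means zero out for every input; the pullback is not called.
        dy isa SketchZero && return ntuple(_ -> SketchZero(), n_inputs)
        return pullback(dy)
    end
end

# Pullback of x -> x^2 at x = 3.0, written with no method for SketchZero.
square_pullback(dy::Number) = (6.0 * dy,)

wrapped = skip_zero(square_pullback, 1)
wrapped(2.0)           # forwarded to square_pullback
wrapped(SketchZero())  # short-circuits, no MethodError
```

With a wrapper like this in the AD system, individual pullbacks would not each need their own ZeroTangent method, at least for single-output functions.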
