
Type inference fails for custom type broadcasting #1359

Closed
tansongchen opened this issue Jan 13, 2023 · 6 comments

Comments

@tansongchen

Say I broadcast over an array of a custom type. For the following MWE, @code_warntype shows that the return type of pb(d_result) cannot be inferred:

using Zygote, InteractiveUtils
import Base: exp

struct MyFloat64 <: Number
    n::Float64
end

exp(f::MyFloat64) = MyFloat64(exp(f.n))
my_vector = MyFloat64[1., 2., 3.]
result, pb = Zygote._pullback(broadcast, exp, my_vector)
d_result = MyFloat64[1., 1., 1.]
@code_warntype pb(d_result)

The output is:

MethodInstance for (::Zygote.var"#3909#back#952"{typeof(∂(broadcasted))})(::Vector{MyFloat64})
  from (::Zygote.var"#3909#back#952")(Δ) in Zygote at /Users/tansongchen/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
Arguments
  #self#::Zygote.var"#3909#back#952"{typeof(∂(broadcasted))}
  Δ::Vector{MyFloat64}
Body::Tuple{Nothing, Any, Any}
1 ─ %1 = Core.getfield(#self#, Symbol("#3908#_back"))::typeof(∂(broadcasted))
│   %2 = ZygoteRules.unthunk_tangent(Δ)::Vector{MyFloat64}
│   %3 = (%1)(%2)::Tuple{Nothing, Any, Any}
│   %4 = (ZygoteRules.gradtuple0)(%3)::Tuple{Nothing, Any, Any}
└──      return %4

To the best of my knowledge of Zygote, the magic happens here, but I couldn't tell what in the pullback function makes it type-unstable. Real and Complex are rescued by ForwardDiff, but custom types are not.

P.S. I found a similar issue, #885, from two years ago, but it never really got solved.

tansongchen changed the title from "Zygote type inference fails for custom type broadcasting" to "Type inference fails for custom type broadcasting" on Jan 13, 2023
@ToucheSir
Member

Cthulhu reports a Union{} (i.e. a call that can only throw) inside https://github.com/FluxML/Zygote.jl/blob/v0.6.54/src/lib/broadcast.jl#L205-L209, and running the pullback shows why:

julia> pb(d_result)
ERROR: MethodError: no method matching conj(::MyFloat64)
Closest candidates are:
  conj(::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S}) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/symmetric.jl:368
  conj(::SparseArrays.SparseVector{<:Complex}) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/SparseArrays/src/sparsevector.jl:1215
  conj(::ChainRulesCore.Tangent) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/tangent.jl:167
  ...
Stacktrace:
  [1] exp_pullback
    @ ~/.julia/packages/ChainRules/RZYEu/src/rulesets/Base/fastmath_able.jl:56 [inlined]
  [2] ZBack
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/chainrules.jl:206 [inlined]
  [3] (::Zygote.var"#938#943")(::Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}, ȳ₁::MyFloat64)
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/lib/broadcast.jl:205
  [4] (::Base.var"#4#5"{Zygote.var"#938#943"})(a::Tuple{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}, MyFloat64})
    @ Base ./generator.jl:36
  [5] iterate
    @ ./generator.jl:47 [inlined]
  [6] collect
    @ ./array.jl:787 [inlined]
  [7] map
    @ ./abstractarray.jl:3055 [inlined]
  [8] (::Zygote.var"#∇broadcasted#942"{Tuple{Vector{MyFloat64}}, Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Val{2}})(ȳ::Vector{MyFloat64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/lib/broadcast.jl:205
  [9] #3885#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [10] #208
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:206 [inlined]
 [11] #2066#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [12] Pullback
    @ ./broadcast.jl:1298 [inlined]
 [13] (::Zygote.var"#3909#back#952"{typeof(∂(broadcasted))})(Δ::Vector{MyFloat64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [14] top-level scope
    @ REPL[11]:1
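The trace ends in ChainRules' exp_pullback calling conj on a MyFloat64. Because MyFloat64 <: Number rather than Real, Base's conj(x::Real) = x fallback does not apply. The pure-Julia sketch below is a rough illustration, assuming the exp rule scales the incoming cotangent by the conjugate of the stored result exp(x). The helper exp_pullback_sketch is hypothetical and is not the ChainRules implementation. It shows that such a pullback needs both conj and a * between two MyFloat64 values:

```julia
struct MyFloat64 <: Number
    n::Float64
end

Base.exp(f::MyFloat64) = MyFloat64(exp(f.n))
Base.conj(f::MyFloat64) = f                          # a real-valued type is its own conjugate
Base.:*(a::MyFloat64, b::MyFloat64) = MyFloat64(a.n * b.n)

# Hypothetical stand-in for a scalar exp rule: d/dx exp(x) = exp(x), so the
# pullback scales the cotangent Δ by conj of the stored primal result Ω.
function exp_pullback_sketch(x)
    Ω = exp(x)
    pullback(Δ) = conj(Ω) * Δ                        # needs conj and * for MyFloat64
    return Ω, pullback
end

Ω, back = exp_pullback_sketch(MyFloat64(1.0))
dx = back(MyFloat64(1.0))                            # MyFloat64 wrapping exp(1.0)
```

Removing the Base.conj line makes back throw the same kind of MethodError as in the trace.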

@tansongchen
Author

tansongchen commented Jan 13, 2023

I oversimplified the MWE a little 😅. Thanks for introducing Cthulhu, great tool! I wasn't aware of it. The actual MWE, in which pb runs, is:

using Zygote, InteractiveUtils
import Base: exp, conj, +, -, *, /

struct MyFloat64 <: Number
    n::Float64
end

*(f1::MyFloat64, f2::MyFloat64) = MyFloat64(f1.n * f2.n)
exp(f::MyFloat64) = MyFloat64(exp(f.n))
conj(f::MyFloat64) = f
my_vector = MyFloat64[1.0, 2.0, 3.0]
result, pb = Zygote._pullback(broadcast, exp, my_vector)
d_result = MyFloat64[1.0, 1.0, 1.0]
pb(d_result)

Running Cthulhu on it now, I get:

(::Zygote.var"#∇broadcasted#942")(ȳ) in Zygote at /Users/tansongchen/.julia/packages/Zygote/AS0Go/src/lib/broadcast.jl:204
    ∘ ─ %0 = invoke ∇broadcasted(::Vector{MyFloat64})::Tuple{Nothing, Any, Any}                                                                              
205 1 ─ %1  = Core.getfield(#self#, :y∂b)::Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}
    │   %2  = Core.tuple(%1, ȳ)::Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}
    │   %3  = %new(Base.Iterators.Zip{Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}}, %2)::Base.Iterators.Zip{Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}}
    │   %4  = %new(Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}}, Base.var"#4#5"{Zygote.var"#938#943"}}, Base.var"#4#5"{Zygote.var"#938#943"}(Zygote.var"#938#943"()), %3)::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}}, Base.var"#4#5"{Zygote.var"#938#943"}}
    │   %5  = invoke Base.collect(%4::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1307"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Vector{MyFloat64}}}, Base.var"#4#5"{Zygote.var"#938#943"}})::Vector{Tuple{Nothing, MyFloat64}}
206 │   %6  = %new(Zygote.var"#939#944"{Vector{Tuple{Nothing, MyFloat64}}}, %5)::Zygote.var"#939#944"{Vector{Tuple{Nothing, MyFloat64}}}               │     
    │   %7  = invoke %6(1::Int64)::Union{Nothing, Vector}                                                                                              │╻     ntuple
    │   %8  = invoke %6(2::Int64)::Union{Nothing, Vector}                                                                                              ││    
    │   %9  = Core.tuple(%7, %8)::Tuple{Union{Nothing, Vector}, Union{Nothing, Vector}}                                                                ││    
209 │   %10 = Zygote.accum_sum(%7)::Any  
...

%6 involves a StaticGetter. Maybe type inference isn't clever enough to get Vector{MyFloat64} by applying a StaticGetter to a Vector{Tuple{Nothing, MyFloat64}}?

In theory, I believe map(StaticGetter{1}(), dxs_zip) should give Vector{Nothing} and map(StaticGetter{2}(), dxs_zip) should give Vector{MyFloat64}. That's sound.

@tansongchen
Author

tansongchen commented Jan 13, 2023

So let's reproduce it in isolation:

struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

function test_type_inference()
    dxs_zip = Tuple{Nothing, Float64}[(nothing, 1.)]

    dxs = ntuple(Val(2)) do i
        map(StaticGetter{i}(), dxs_zip)
    end
    dxs
end

and the inferred type is lost (note Tuple{Vector, Vector}):

julia> @code_warntype test_type_inference()
MethodInstance for test_type_inference()
  from test_type_inference() in Main at /Users/tansongchen/Applications/project/TaylorDiff/.vscode/broadcast.jl:21
Arguments
  #self#::Core.Const(test_type_inference)
Locals
  #9::var"#9#10"{Vector{Tuple{Nothing, Float64}}}
  dxs::Tuple{Vector, Vector}
  dxs_zip::Vector{Tuple{Nothing, Float64}}
Body::Tuple{Vector, Vector}
1 ─ %1 = Core.apply_type(Main.Tuple, Main.Nothing, Main.Float64)::Core.Const(Tuple{Nothing, Float64})
│   %2 = Core.tuple(Main.nothing, 1.0)::Core.Const((nothing, 1.0))
│        (dxs_zip = Base.getindex(%1, %2))
│   %4 = Main.:(var"#9#10")::Core.Const(var"#9#10")
│   %5 = Core.typeof(dxs_zip)::Core.Const(Vector{Tuple{Nothing, Float64}})
│   %6 = Core.apply_type(%4, %5)::Core.Const(var"#9#10"{Vector{Tuple{Nothing, Float64}}})
│        (#9 = %new(%6, dxs_zip))
│   %8 = #9::var"#9#10"{Vector{Tuple{Nothing, Float64}}}
│   %9 = Main.Val(2)::Core.Const(Val{2}())
│        (dxs = Main.ntuple(%8, %9))
└──      return dxs
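For comparison, here is a sketch that builds the same tuple with the two indices written out literally instead of passing them through the ntuple closure. The function name test_unrolled is hypothetical. Each StaticGetter{i} is now a concrete type at compile time, so no constant propagation is needed. Assuming standard map inference, Test.@inferred should accept the call:

```julia
using Test  # for @inferred

struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

# Same data as test_type_inference, but the indices 1 and 2 are literal type
# parameters rather than values flowing through the ntuple closure.
function test_unrolled()
    dxs_zip = Tuple{Nothing, Float64}[(nothing, 1.0)]
    return (map(StaticGetter{1}(), dxs_zip), map(StaticGetter{2}(), dxs_zip))
end

# @inferred throws if the inferred return type differs from the runtime type.
dxs = @inferred test_unrolled()
```

If this version infers concretely while test_type_inference does not, the problem lies in the index losing its constness inside ntuple rather than in StaticGetter itself.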

@ToucheSir
Member

I also eventually found the correct methods to implement. The issue appears to be some compiler heuristic kicking in, which means the index from ntuple is not constant-propagated, so StaticGetter is not so static after all. PR in #1360.

@tansongchen
Author

Thank you for the quick fix! I verified your corrections locally by dev'ing Zygote. For my real (non-MWE) case, being able to infer the type makes it 30× faster 🤣.

In case you're wondering why I keep raising corner cases of Zygote: I'm building a fast higher-order forward-mode AD (essentially a rewrite of TaylorSeries.jl with statically inferred polynomial types), and I want it to be compatible with Zygote downstream so that it can be used in cases like NeuralPDE.jl.

@CarloLucibello
Member

Closed by #1360

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants