
Type instability with structs #1094

Closed
kaandocal opened this issue Oct 5, 2021 · 13 comments

@kaandocal

Zygote seems to lose type stability (unnecessarily?) when it is used with structs like Flux.Dense.

```julia
using Flux, Zygote

layer = Dense(1, 3)

@code_warntype layer([1])
```

yields

```
Variables
  a::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}
  x::Vector{Int64}
  σ::typeof(identity)
  b::Vector{Float32}
  W::Matrix{Float32}

Body::Vector{Float32}
1 ─ %1  = Base.getproperty(a, :weight)::Matrix{Float32}
│   %2  = Base.getproperty(a, :bias)::Vector{Float32}
│   %3  = Base.getproperty(a, :σ)::Core.Const(identity)
│         (W = %1)
│         (b = %2)
│         (σ = %3)
│   %7  = σ::Core.Const(identity)
│   %8  = (W * x)::Vector{Float32}
│   %9  = Base.broadcasted(Flux.:+, %8, b)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Vector{Float32}, Vector{Float32}}}
│   %10 = Base.broadcasted(%7, %9)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(identity), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Vector{Float32}, Vector{Float32}}}}}
│   %11 = Base.materialize(%10)::Vector{Float32}
└──       return %11
```

On the other hand, with Zygote,

```julia
@code_warntype Zygote._pullback(Zygote.Context(), layer, [5])
```

yields

```
Variables
  #self#::Core.Const(ZygoteRules._pullback)
  ctx::Context
  f::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}
  args::Tuple{Vector{Int64}}

Body::Tuple{Any, typeof(∂(λ))}
1 ─       $(Expr(:meta, :inline))
│   %2  = Base.getfield(args, 1)::Vector{Int64}
│   %3  = Zygote._pullback(ctx, Zygote.literal_getproperty, f, Val{:weight}())::Core.PartialStruct(Tuple{Union{typeof(identity), Array{Float32, N} where N}, typeof(∂(literal_getproperty))}, Any[Union{typeof(identity), Array{Float32, N} where N}, Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])])
│   %4  = Base.getindex(%3, 1)::Union{typeof(identity), Array{Float32, N} where N}
│   %5  = Base.getindex(%3, 2)::Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])
│   %6  = Zygote._pullback(ctx, Zygote.literal_getproperty, f, Val{:bias}())::Core.PartialStruct(Tuple{Union{typeof(identity), Array{Float32, N} where N}, typeof(∂(literal_getproperty))}, Any[Union{typeof(identity), Array{Float32, N} where N}, Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])])
│   %7  = Base.getindex(%6, 1)::Union{typeof(identity), Array{Float32, N} where N}
│   %8  = Base.getindex(%6, 2)::Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])
│   %9  = Zygote._pullback(ctx, Zygote.literal_getproperty, f, Val{:σ}())::Core.PartialStruct(Tuple{Union{typeof(identity), Array{Float32, N} where N}, typeof(∂(literal_getproperty))}, Any[Union{typeof(identity), Array{Float32, N} where N}, Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])])
│   %10 = Base.getindex(%9, 1)::Union{typeof(identity), Array{Float32, N} where N}
│   %11 = Base.getindex(%9, 2)::Core.PartialStruct(typeof(∂(literal_getproperty)), Any[Core.PartialStruct(Tuple{typeof(∂(getproperty))}, Any[Core.PartialStruct(typeof(∂(getproperty)), Any[Core.PartialStruct(Tuple{Zygote.ZBack{ChainRules.var"#===_pullback#77"}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8}, Any[Core.Const(Zygote.ZBack{ChainRules.var"#===_pullback#77"}(ChainRules.var"#===_pullback#77"())), Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, Vector{Any}, UInt8])])])])
│   %12 = Zygote._pullback(ctx, Flux.:*, %4, %2)::Any
│   %13 = Base.getindex(%12, 1)::Any
│   %14 = Base.getindex(%12, 2)::Any
│   %15 = Zygote._pullback(ctx, Base.broadcasted, Flux.:+, %13, %7)::Any
│   %16 = Base.getindex(%15, 1)::Any
│   %17 = Base.getindex(%15, 2)::Any
│   %18 = Zygote._pullback(ctx, Base.broadcasted, %10, %16)::Any
│   %19 = Base.getindex(%18, 1)::Any
│   %20 = Base.getindex(%18, 2)::Any
│   %21 = Zygote._pullback(ctx, Base.materialize, %19)::Any
│   %22 = Base.getindex(%21, 1)::Any
│   %23 = Base.getindex(%21, 2)::Any
(...)
```

It seems the culprit is that Zygote.literal_getproperty falls back to Base.getproperty, which ends up type-unstable here. Note the Union{typeof(identity), Array{Float32, N} where N} in the output: it is the (widened) union of Dense's field types, i.e. the type inference assigns to an arbitrary field of Dense when it does not know which field is being read.
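A minimal, Zygote-free sketch of the same inference effect, using only Base and the Test stdlib. `MyDense`, `get_dynamic` and `get_static` are hypothetical names for illustration, not Flux or Zygote API, and this does not reproduce Zygote's internals. It only shows why a union of field types appears: when the field name is a runtime `Symbol`, inference cannot tell which field is read, whereas lifting the name into the type domain (as the `Val{:weight}()` arguments in the `literal_getproperty` calls above do) lets inference return the concrete field type.

```julia
using Test

# Hypothetical stand-in for Flux.Dense (not the real Flux type)
struct MyDense{F,M,V}
    weight::M
    bias::V
    σ::F
end

# Field name is a runtime value: inference only sees `Symbol`,
# so the return type is a merge of all the field types.
get_dynamic(x, s::Symbol) = getfield(x, s)

# Field name lifted into the type domain: inference knows which field is read.
get_static(x, ::Val{s}) where {s} = getfield(x, s)

layer = MyDense(rand(Float32, 3, 1), zeros(Float32, 3), identity)

# Passes: the inferred return type is exactly Matrix{Float32}
@inferred get_static(layer, Val(:weight))

# `@inferred` throws because the inferred type is not the concrete Matrix{Float32}
@test_throws ErrorException @inferred get_dynamic(layer, :weight)

# Inspect what inference returns for a runtime field name
println(Base.return_types(get_dynamic, (typeof(layer), Symbol)))
```

The same `@inferred` pattern can be pointed at any function you expect to be type-stable, which is less error-prone than reading `@code_warntype` output by eye.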

Is this intended behaviour? I am trying to optimise a training routine using Flux, and profiling shows a lot of type instability and GC whenever Zygote is involved, which likely causes some slowdown. Is there any way to avoid this, or to "precompile" the loss function (assuming no dynamic branching etc.) for better performance?

@willtebbutt
Member

willtebbutt commented Oct 5, 2021

This is, unfortunately, a known issue.

There's a PR that's been open for a long time to try and fix the problem, but no progress has been made recently: #909

In the meantime, you can work around it (provided that you're using the default getproperty implementation) with the _pullback suggested in FluxML/ZygoteRules.jl#21.

Specifically, see the docstring for pullback_for_default_literal_getproperty.

You have to do this for every type for which you want getproperty to be efficient.

@DhairyaLGandhi
Member

Will beat me to it, but yes #909 will fix this. @simeonschaub should we revive it? I think it would be the right fit for such use cases

@willtebbutt
Member

willtebbutt commented Oct 5, 2021

That would be nice. Every time I try and do anything performant I have to write more code than I would like.

@kaandocal
Author

kaandocal commented Oct 5, 2021

Thanks for the answers, and for working on this! I did not think of searching for this issue among the pull requests... It would be very useful if this could be mentioned in the Performance section of the docs. I'm pretty sure the workaround can be included as a simple macro similar to @functor :)

@willtebbutt
Member

> It would be very useful if this could be mentioned in the Performance section of the docs.

Counterpart PR to the ZygoteRules one: #1091 :)

@kaandocal
Author

kaandocal commented Oct 5, 2021

@willtebbutt This does render the code I'm using fully type-stable again under _pullback, which is great! I'm happy with the workaround right now. Maybe one could incorporate this into the basic Flux types, unless a true fix is feasible soon? I'll have a look at the discussion there...

@willtebbutt
Member

@DhairyaLGandhi what are your thoughts about properly documenting a work-around vs convincing @simeonschaub to finish up his PR?

@kaandocal
Author

Would it be possible to figure out which method of Base.getproperty is called when a field is accessed and automatically apply the workaround if that is the default implementation?

@willtebbutt
Member

willtebbutt commented Oct 5, 2021

IIUC that's essentially what #909 is doing.

I still don't fully understand how reliable a strategy it is, because the docs for generated functions suggest it's a bad idea to depend on mutable global state inside of generated functions (the method table being mutable global state). That being said, Zygote does this kind of thing to interop with ChainRules and it seems to work (up to the need to call Zygote.refresh() once in a while), so maybe it's going to be fine 🤷

(I think what I've said is broadly correct -- if someone knows better, please do correct my misunderstanding :) )
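The check proposed above ("is the default `getproperty` being called?") can be sketched with Base's `which`, which returns the `Method` a call with the given argument types dispatches to. The helper `uses_default_getproperty` and the types `Plain` and `Custom` are hypothetical, and this is not claimed to be how #909 is implemented; it only shows that the question is answerable by reflection on the method table. As discussed above, running such a check inside a generated function means depending on the method table, which is mutable global state.

```julia
# Hypothetical helper: true if `getproperty(::T, ::Symbol)` resolves to Base's
# generic fallback method (the one that simply forwards to `getfield`).
function uses_default_getproperty(::Type{T}) where {T}
    return which(getproperty, Tuple{T, Symbol}) === which(getproperty, Tuple{Any, Symbol})
end

# A plain struct relying on default field access
struct Plain
    a::Int
end

# A struct with a custom getproperty, e.g. to keep a deprecated field name working
struct Custom
    a::Int
end
Base.getproperty(c::Custom, s::Symbol) = s === :alias ? getfield(c, :a) : getfield(c, s)

println(uses_default_getproperty(Plain))   # true
println(uses_default_getproperty(Custom))  # false
```

A type like `Custom` would not qualify for a workaround that assumes the default implementation, which matches the later observation in this thread that `Dense` defines its own getproperty.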

@DhairyaLGandhi
Member

I largely prefer the direction of #909 and it seems to not hurt existing overrides. I should give the workaround a shake too.

@willtebbutt
Member

Agreed -- I'd prefer 909, I just wonder if it will ever get finished.

@willtebbutt
Member

#909 has just been merged. @kaandocal when you can get a copy of v0.6.27, could you check that your code is performant without the workaround?

@kaandocal
Author

kaandocal commented Oct 18, 2021

It seems to work now, thanks! By the way it turns out Dense uses a custom getproperty implementation due to API deprecations, but for other structs the types are now inferred correctly.
