
Inferrability of cat #149

Closed
wants to merge 2 commits

Conversation

@DhairyaLGandhi (Member) commented Apr 21, 2022

cat fails to infer for non-constant arguments (a possible improvement is in JuliaLang/julia#45028). This causes a lot of churn in models that make use of cat. I defined an rrule, but that didn't infer properly either (I think I know why, but we can protect ourselves from accidental regressions).
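For context, a minimal sketch of the failure mode (my illustration, not code from this PR): when dims is only known at runtime, the output rank of cat cannot be inferred, whereas encoding it in the type domain as Val(n) recovers a concrete return type.

julia> using Test

julia> r = rand(Float32, 4, 4, 2, 1);

julia> h(x, y, d) = cat(x, y; dims = d);         # d is a runtime Int

julia> hv(x, y, d::Val) = cat(x, y; dims = d);   # rank carried in the type

julia> @inferred hv(r, r, Val(3));               # passes: concrete Array{Float32, 4}

julia> @inferred h(r, r, 3)                      # throws: inferred return type is abstract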

rrule:

julia> r = rand(Float32, 56, 56, 64, 1);

julia> @code_warntype gradient((m,x) -> sum(inferredcat(m, x, dims = 3)), r, r);
MethodInstance for Zygote.gradient(::var"#18#19", ::Array{Float32, 4}, ::Array{Float32, 4})
  from gradient(f, args...) in Zygote at /home/dhairyalgandhi/FluxBench.jl/Zygote.jl/src/compiler/interface.jl:74
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#18#19"())
  args::Tuple{Array{Float32, 4}, Array{Float32, 4}}
Locals
  @_4::Int64
  grad::Tuple{Any, Any}
  back::Zygote.var"#72#73"{typeof((λ))}
  y::Float32
Body::Tuple{Any, Any}
1%1  = Core.tuple(f)::Core.Const((var"#18#19"(),))                                                                   [189/1801]
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Core.PartialStruct(Tuple{Float32, Zygote.var"#72#73"{typeo
f((λ))}}, Any[Float32, Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.Partial
Struct(Tuple{Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Zygote.var"#kw_zpullback#57"{var"#
inferredcat_pullback#3"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[Zyg
ote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Core.PartialStruct(Zygote.var"#kw_zpullback#57"{va
r"#inferredcat_pullback#3"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(var"#inferredcat_pullback#3"
{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}, Any[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"#
1639#back#177"{typeof(identity)}])])])])
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Float32, Int64}, Any[Float32, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#72#73"{typeof((λ))}, Int64}, Any[
Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.PartialStruct(Tuple{Zygote.var"
#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Zygote.var"#kw_zpullback#57"{var"#inferredcat_pullback#3"{
Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[Zygote.var"#2730#back#649"{
Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Core.PartialStruct(Zygote.var"#kw_zpullback#57"{var"#inferredcat_pullback#
3"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(var"#inferredcat_pullback#3"{Int64, Tuple{NTuple{4,
Int64}, NTuple{4, Int64}}}, Any[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"#1639#back#177"{typeof(id
entity)}])])]), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Core.Const(1.0f0)
│         (grad = (back::Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.Partia
lStruct(Tuple{Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Zygote.var"#kw_zpullback#57"{var"
#inferredcat_pullback#3"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[Zy
gote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, typeof((λ)), Core.PartialStruct(Zygote.var"#kw_zpullback#57"{v
ar"#inferredcat_pullback#3"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(var"#inferredcat_pullback#3
"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}, Any[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"
#1639#back#177"{typeof(identity)}])])]))(%8))
│   %10 = Zygote.isnothing(grad)::Core.Const(false)
└──       goto #3 if not %10
2 ─       Core.Const(:(return Zygote.nothing))
3 ┄ %13 = Zygote.map(Zygote._project, args, grad)::Tuple{Any, Any}
└──       return %13

adjoint:

julia> @code_warntype gradient((x,y) -> sum(Metalhead.inferredcat(x, y, dims = 3)), r, r)
MethodInstance for Zygote.gradient(::var"#39#40", ::Array{Float32, 4}, ::Array{Float32, 4})
  from gradient(f, args...) in Zygote at /home/dhairyalgandhi/FluxBench.jl/Zygote.jl/src/compiler/interface.jl:74
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#39#40"())
  args::Tuple{Array{Float32, 4}, Array{Float32, 4}}
Locals
  @_4::Int64
  grad::Tuple{FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}}
  back::Zygote.var"#72#73"{typeof((λ))}
  y::Float32
Body::Tuple{FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}}
1%1  = Core.tuple(f)::Core.Const((var"#39#40"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Core.PartialStruct(Tuple{Float32, Zygote.var"#72#73"{typeo
f((λ))}}, Any[Float32, Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.Partial
Struct(Tuple{typeof((λ)), Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, Metalhead.var"#10#back#18"{Metalhe
ad.var"#3#9"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[typeof((λ)),
Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, Core.PartialStruct(Metalhead.var"#10#back#18"{Metalhead.var"#
3#9"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(Metalhead.var"#3#9"{Int64, Tuple{NTuple{4, Int64},
 NTuple{4, Int64}}}, Any[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"#1639#back#177"{typeof(identity)
}])])])])
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Float32, Int64}, Any[Float32, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#72#73"{typeof((λ))}, Int64}, Any[
Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.PartialStruct(Tuple{typeof((λ)
), Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, Metalhead.var"#10#back#18"{Metalhead.var"#3#9"{Int64, Tupl
e{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[typeof((λ)), Zygote.var"#2730#back#64
9"{Zygote.var"#645#647"{Array{Float32, 4}}}, Core.PartialStruct(Metalhead.var"#10#back#18"{Metalhead.var"#3#9"{Int64, Tuple{NTuple
{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(Metalhead.var"#3#9"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}, Any
[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"#1639#back#177"{typeof(identity)}])])]), Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Core.Const(1.0f0)
│         (grad = (back::Core.PartialStruct(Zygote.var"#72#73"{typeof((λ))}, Any[Core.PartialStruct(typeof((λ)), Any[Core.Partia
lStruct(Tuple{typeof((λ)), Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, Metalhead.var"#10#back#18"{Metalh
ead.var"#3#9"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Zygote.var"#1639#back#177"{typeof(identity)}}, Any[typeof((λ)),
 Zygote.var"#2730#back#649"{Zygote.var"#645#647"{Array{Float32, 4}}}, Core.PartialStruct(Metalhead.var"#10#back#18"{Metalhead.var"
#3#9"{Int64, Tuple{NTuple{4, Int64}, NTuple{4, Int64}}}}, Any[Core.PartialStruct(Metalhead.var"#3#9"{Int64, Tuple{NTuple{4, Int64}
, NTuple{4, Int64}}}, Any[Core.Const(3), Tuple{NTuple{4, Int64}, NTuple{4, Int64}}])]), Zygote.var"#1639#back#177"{typeof(identity
)}])])]))(%8))
│   %10 = Zygote.isnothing(grad)::Core.Const(false)
└──       goto #3 if not %10
2 ─       Core.Const(:(return Zygote.nothing))
3 ┄ %13 = Zygote.map(Zygote._project, args, grad)::Tuple{FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 4, NTuple{4, Base.OneTo{Int64}}}}
└──       return %13

I also ran it against FluxBench and pulled out some of the benchmarks from there:

8×4 DataFrame
 Row │ benchmark                                                           min_PR   min_master   min_baseline
     │ String                                                              Float64   Float64?    Float64?
─────┼─────────────────────────────────────────────────────────────────────────────────────────────────────
   1 │ Metalhead/Backwards_Pass_Metalhead.ResNet_34_with_batchsize_10      8.35454   10.7336     14.2336
   2 │ Metalhead/Backwards_Pass_Metalhead.ResNet_50_with_batchsize_10      9.08442   11.654      13.7857
   3 │ Metalhead/Backwards_Pass_Metalhead.DenseNet_121_with_batchsize_5    7.57214    8.80217     7.79749
   4 │ Metalhead/Backwards_Pass_Metalhead.ResNet_34_with_batchsize_5       4.83328    6.04459     8.06764
   5 │ Metalhead/Backwards_Pass_Metalhead.ResNet_18_with_batchsize_10      9.14701    8.57798    14.7049
   6 │ Flux3D/Flux3D_TriMesh_Backward_Pass_CUDA                            0.015529   0.0162249    missing
   7 │ Transformers/Bert-base-uncased_Backward_Pass_seq_len_8_nbatch_8     0.265012   0.478007     missing
   8 │ Transformers/Bert-base-uncased_Backward_Pass_seq_len_32_nbatch_8    0.837862   1.03684      missing

Note that there is a massive regression in the compile time of certain models (like DenseNet).

cat_channels(xy...) = cat(xy...; dims = 3)            # before
cat_channels(xy...) = inferredcat(xy...; dims = 3)    # after

function inferredcat(xs::T...; dims = :)::T where T <: AbstractArray
@DhairyaLGandhi (Member, Author) commented Apr 21, 2022

Note that this limits it to arrays of the same type.
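For instance (a hypothetical call, not from this PR), under the xs::T... signature two arrays that differ only in element type no longer share a single T:

julia> a = rand(Float32, 2, 2, 1);

julia> b = rand(Float64, 2, 2, 1);    # different eltype, hence a different array type

julia> inferredcat(a, b; dims = 3)    # would throw a MethodError: no single T matches both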

@ToucheSir (Member) commented Apr 23, 2022

As luck would have it, https://github.com/JuliaDiff/ChainRules.jl/blob/8c34f19d3a8a8a224c9fbe20524d2a08c8b9bf81/src/rulesets/Base/array.jl#L345 actually is type stable if you invoke cat with dims=Val(n). All that's needed is to remove the existing adjoint at https://github.com/FluxML/Zygote.jl/blob/8b9d67dbc29ef438c65e9ae92ce5c68c2188cd6b/src/lib/array.jl#L119. The only caveat is that gradient((m,x) -> sum(cat(m, x, dims = Val(3))), r, r) will fail right now because the pullback is given a Fill, but I've opened JuliaDiff/ChainRules.jl#610 to address that.
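As a quick sanity check (my sketch, invoking the rrule method from the ChainRules.jl file linked above; a dense cotangent is used here to sidestep the Fill caveat):

julia> using ChainRules, Test

julia> r = rand(Float32, 56, 56, 64, 1);

julia> y, cat_pullback = ChainRules.rrule(cat, r, r; dims = Val(3));

julia> @inferred cat_pullback(ones(Float32, 56, 56, 128, 1));   # type stable with Val dims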

Given the relative ease of this approach, it's also worth exploring changing cat_channels to use ndims - 1 instead of hard-coding 3. This would allow us to accommodate 1D and 3D convs, as well as dense layers (since the feature dim can be considered the channel dim there); see the sketch below. Perhaps it could be moved to NNlib or MLUtils as well so that more libraries can make use of it.
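A sketch of that generalization (hypothetical, not part of this PR; it relies on constant propagation of ndims so that the Val stays type stable):

# concatenate along the channel dimension, i.e. the next-to-last axis:
# dim 3 for 4D conv activations, dim 1 for (features × batch) Dense inputs
cat_channels(xs::AbstractArray...) = cat(xs...; dims = Val(ndims(first(xs)) - 1))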

@DhairyaLGandhi (Member, Author)

The JuliaLang PR should resolve the inferrability of the Zygote adjoint. FWIW, I had tried threading the existing rrule through and only wrote this to get stability for the entire differentiation pipeline.

@ToucheSir (Member)

It's unclear at this point whether the Base PR will be backported. In contrast, using the ChainRules rrule works back to at least 1.6 and is something we should be doing anyway (deleting old redundant adjoints in Zygote reduces our maintenance burden).

@DhairyaLGandhi (Member, Author)

Happy to backport if that is the blocker.

@ToucheSir (Member) commented Apr 26, 2022

My point was that we can get a type-stable cat_channels, forwards and backwards, with basically net-negative lines of code added. The first part would be removing the existing cat adjoint, which is now tracked in FluxML/Zygote.jl#1212. The second part would be to change the definition of cat_channels to:

cat_channels(xy...) = cat(xy...; dims = Val(3))

The Base PR is nice but not at all necessary for this approach.
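An end-to-end check might look like this (my sketch; it assumes a Zygote build with the old cat adjoint removed per FluxML/Zygote.jl#1212, plus the ChainRules fix above):

julia> using Zygote, Test

julia> cat_channels(xy...) = cat(xy...; dims = Val(3));

julia> r = rand(Float32, 56, 56, 64, 1);

julia> y, back = Zygote.pullback((m, x) -> sum(cat_channels(m, x)), r, r);

julia> @inferred back(one(Float32));   # should infer once the adjoint is gone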

@ToucheSir (Member)

Now that the aforementioned two PRs have landed and #170 has changed cat_channels to use Val, we can settle this. Further improvements to inference without using Val can be left to Base and ChainRules :)

@ToucheSir ToucheSir closed this Jun 19, 2022
@CarloLucibello CarloLucibello deleted the dg/infer branch July 17, 2023 05:52