Error from gradient of vcat(x...) - appeared in v0.6.45 #1417

Open
danielalcalde opened this issue Apr 28, 2023 · 6 comments
Labels
bug Something isn't working ChainRules adjoint -> rrule, and further integration

Comments

@danielalcalde

In GTorlai/PastaQ.jl#300 (comment) a bug was detected that I have found to stem from a problem in the differentiation of vcat. I created a minimal example to reproduce the error:

using Zygote
function loss(theta)
    x1 = vcat([theta], 5)
    x2 = vcat(x1...)
    return x2[1]
end
println(gradient(loss, 1))
ERROR: LoadError: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(::Tuple{Float64})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Union{LinearAlgebra.Adjoint{T, var"#s886"}, LinearAlgebra.Transpose{T, var"#s886"}} where {T, var"#s886"<:(AbstractVector)}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:247
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{<:ChainRulesCore.AbstractZero}) at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:244
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{S, M}) where {S, M} at ~/.julia/packages/ChainRulesCore/a4mIA/src/projection.jl:219
  ...
Stacktrace:
  [1] (::ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}})()
    @ ChainRules ~/.julia/packages/ChainRules/aKxNz/src/rulesets/Base/array.jl:310
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:204 [inlined]
  [3] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1413#1419"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}}, ChainRules.var"#1412#1418"{Tuple{UnitRange{Int64}}, ChainRulesCore.Tangent{Any, Tuple{Float64, Float64}}}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:237
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:110 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:111 [inlined]
  [7] (::Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}})(dy::Tuple{Float64, Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211
  [8] Pullback
    @ ~/workprojects/education/julia/pastaq/break.jl:3 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(loss), Int64}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{Int64}, Tuple{}}, Val{1}}}, Zygote.ZBack{ChainRules.var"#vect_pullback#1369"{1, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(Zygote.literal_getindex), Vector{Int64}, Val{1}}, Tuple{Zygote.var"#2571#back#528"{Zygote.var"#538#540"{1, Int64, Vector{Int64}, Tuple{Int64}}}}}}}})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [11] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

This code worked in Zygote@0.6.44, but it fails starting with Zygote@0.6.45 and up to Zygote@0.6.60.

@ToucheSir ToucheSir added bug Something isn't working ChainRules adjoint -> rrule, and further integration labels Apr 29, 2023
@ToucheSir
Member

0.6.45 is when we switched over to ChainRules for the cat functions: #1277. TBD whether ChainRules projection isn't being flexible enough or if Zygote is passing invalid inputs to it.

@mcabbott
Member

I think this is #599: x1... makes a Tuple, but the gradient of x1 ought to be an array. It's been worked around in some cases (e.g. with _project) but not all.
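A minimal sketch in plain Julia (no Zygote involved) of the first half of that diagnosis: splatting a vector into a varargs function hands the callee a Tuple, so anything computed per argument naturally comes back tuple-shaped. The helper name `collect_args` is hypothetical and exists only for illustration.

```julia
# Hypothetical helper: returns its varargs exactly as the callee sees them.
collect_args(args...) = args

x1 = [2.0, 3.0]                  # a Vector{Float64}
splatted = collect_args(x1...)   # splat the vector into separate arguments

println(typeof(x1))              # Vector{Float64}
println(typeof(splatted))        # Tuple{Float64, Float64}
```

The array structure is gone by the time the callee runs, which is why a rule written for the callee alone has no way to know the caller started from an array.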

@mcabbott mcabbott changed the title Gradient issue with vcat - appeared in v0.6.45 Error from gradient of vcat(x...) - appeared in v0.6.45 May 12, 2023
@theabhirath
Member

I am running into this issue while trying to implement DenseNet. Since vcat is one of the only non-mutating ways to append elements to arrays, this is a blocker for that. Is there a workaround or a fix for this? I confirmed that it was working on 0.6.44, but the error appears on versions higher than that.

@mcabbott
Member

mcabbott commented May 24, 2023

Can you simplify the example, or make other ones? Perhaps the splat isn't the right diagnosis, as things like this seem fine:

julia> gradient([2, 3.0]) do x
         vcat(x...)[1]
       end
([1.0, 0.0],)

julia> gradient([2, 3.0]) do x
         vcat(x..., 4)[1]
       end
([1.0, 0.0],)

No, I think those are getting fixed up afterwards: gradient applies a final _project to its answer. pullback avoids that step, and here the splat clearly makes the tuple:

julia> pullback([2, 3.0]) do x
         vcat(x...)[1]
       end[2](1.0)
((1.0, 0.0),)

The rrule involved cannot fix this, it sees and returns individual arguments:

julia> using ChainRules, ChainRulesCore

julia> rrule(vcat, 3/4, 4/5)[2]([6.6, 7.7])
(NoTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))

julia> unthunk.(ans)
(NoTangent(), 6.6, 7.7)
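To make the mismatch concrete, here is a sketch in plain Julia of turning a tuple of per-argument values back into an array shaped like the original input. It only illustrates the kind of conversion under discussion; it is an assumption for illustration, not Zygote's or ChainRules' actual code, and `tuple_to_array_like` is a hypothetical name.

```julia
# Hypothetical helper, not part of Zygote or ChainRules:
# reshape a tuple of per-argument cotangents to match the array `x`.
tuple_to_array_like(x::AbstractArray, dx::Tuple) = reshape(collect(dx), size(x))

x = [2.0, 3.0]
dx = (1.0, 0.0)    # tuple-shaped gradient, as in the pullback output above

println(tuple_to_array_like(x, dx))   # [1.0, 0.0]
```

A conversion like this needs access to the original array `x`, which the rule for vcat never sees when its arguments arrive already splatted.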

@theabhirath
Member

Differentiating through this is what causes the error for me:

function (m::DenseBlock)(x)
    input = [x]
    for layer in m.layers
        x = layer(input)
        input = vcat(input, [x])
    end
    return cat_channels(input...)
end

This is the only place vcat is used in my code. The layers are mostly simple Chains with Convs and BatchNorms, in case that is useful information. It does seem to suggest that the splat is not the only issue.

@ToucheSir
Member

I think we should be implementing DenseNet differently anyhow (toss up a PR if you want some ideas there), so this shouldn't block Metalhead at least.
