Error with gradient of function based on Dictionary #1421

Open
kishore-nori opened this issue May 12, 2023 · 4 comments
Labels: dictionary, help wanted (Extra attention is needed), up for grabs (anyone is welcome to contribute with a PR to fix the issue)


@kishore-nori

Hi,

I encountered the following errors when differentiating functions based on Dicts. Below are minimal failing examples (MFEs) and my naive attempts at fixing them. (They seem to require some methods and adjoints for the Base.ValueIterator type.)

using Zygote 

function mfe1(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  sum(map(sum,values(collection)))
end

x = rand(3)

Zygote.gradient(mfe1, x)

The above results in the following error:

ERROR: MethodError: no method matching size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:580
  size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/cholesky.jl:514
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:95 [inlined]
  [2] _tryaxes(x::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:188
  [3] map
    @ ./tuple.jl:221 [inlined]
  [4] ∇map(cx::Zygote.Context{false}, f::typeof(sum), args::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:203
  [5] _pullback(cx::Zygote.Context{false}, #unused#::typeof(collect), g::Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:244
  [6] _pullback
    @ ./abstractarray.jl:2961 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(map), ::typeof(sum), ::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] _pullback
    @ ./REPL[2]:4 [inlined]
  [9] _pullback(ctx::Zygote.Context{false}, f::typeof(mfe1), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:44
 [11] pullback
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:42 [inlined]
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:96

Since the error asks for size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}}), and the method length(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}}) does exist, I tried adding the following method:

Base.size(v::Union{Base.KeySet,Base.ValueIterator}) = (length(v.dict),)

I don't know whether this is the right way to go, but it seems to make the forward pass error-free. Now, however, Zygote.gradient requests an adjoint; see the updated error:

ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
  [3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [4] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./abstractdict.jl:131 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[2]:4 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Base.ValueIterator{Dict{Symbol, 
Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), 
Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
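Some context on the first error (my reading of Base's iteration interface, not something stated in the report): values(d) returns a Base.ValueIterator whose Base.IteratorSize is Base.HasLength(). That means it promises length and iteration but no shape. Zygote's ∇map calls axes on its arguments via _tryaxes (frames [1] and [2] of the first trace), and the generic axes fallback calls size, hence the MethodError. Defining Base.size for a Base type from user code, as above, is also type piracy. Collecting the values into a Vector gives an object that does have size and axes. Whether Zygote then differentiates through collect(values(d)) is not checked here. The snippet below is forward-only plain Julia, and the names d, vals and vv are illustrative.

```julia
d = Dict(:a => [1.0, 2.0, 3.0], :b => [1.0, 4.0, 9.0])
vals = values(d)                      # a Base.ValueIterator wrapping d

# The iterator advertises a length but no shape, so size/axes are not defined for it.
Base.IteratorSize(vals) isa Base.HasLength   # true
length(vals)                                 # 2

# Collecting materialises a Vector, which does have size/axes.
# Its element order follows d's (unspecified) iteration order.
vv = collect(vals)                    # Vector{Vector{Float64}}
size(vv)                              # (2,)
```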

Independently of the above, the following alternative MFE,

function mfe2(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  v = vcat(values(collection)...)
  sum(v)
end

throws the same "Need an adjoint" error as above, this time with a Tuple gradient instead of a Vector:

ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
  [3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [4] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./abstractdict.jl:131 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[2]:4 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, 
Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, 
Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
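The two "Need an adjoint" errors differ only in the gradient type that reaches the ValueIterator constructor. For mfe1 it is a Vector of Fill arrays, and for mfe2 it is a Tuple of two Fill arrays. That appears to mirror the forward containers: map over a non-array iterator collects into a Vector, while splatting values(collection)... passes the values on as a tuple. The following is a forward-only illustration in plain Julia, and the variable names are illustrative.

```julia
x = [1.0, 2.0, 3.0]
collection = Dict(:a => x, :b => x.^2)

# mfe1's intermediate: map over the ValueIterator collects into a Vector.
m = map(sum, values(collection))
m isa Vector{Float64}                               # true

# mfe2's splat: the values become separate arguments, i.e. a Tuple.
args = (values(collection)...,)
args isa Tuple{Vector{Float64}, Vector{Float64}}    # true

# Either way the forward results agree (Dict order does not matter for a sum).
sum(m) == sum(vcat(args...))                        # true (6 + 14 == 20)
```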

I would be happy to know whether this is fixable by writing the adjoint the error requests, or whether there is a workaround for this issue. Thank you!
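On the question of writing the adjoint: any rule for values(d) has to send the i-th gradient entry back to the key whose value was i-th in iteration order. For a Dict that is not modified in between, keys(d) and values(d) iterate in the same order, so zipping them recovers that pairing. The sketch below shows only this bookkeeping in plain Julia. values_cotangent is a hypothetical helper. How Zygote actually represents a Dict gradient, and how this would plug into a Zygote.@adjoint or ChainRulesCore.rrule, is not verified here.

```julia
# Hypothetical helper: map a per-value gradient (a Vector or a Tuple, as in the two
# errors above) back onto the Dict's keys, relying on keys/values sharing one order.
function values_cotangent(d::AbstractDict, Δ)
    length(Δ) == length(d) || throw(DimensionMismatch("gradient length must match the Dict"))
    return Dict(k => δ for (k, δ) in zip(keys(d), Δ))
end

d = Dict(:a => [1.0, 2.0], :b => [3.0, 4.0], :c => [5.0, 6.0])

# Pretend the gradient of each value is 10x that value, listed in values(d) order.
Δ = [10 .* v for v in values(d)]
g = values_cotangent(d, Δ)

# Each key receives the gradient computed from its own value.
all(g[k] == 10 .* d[k] for k in keys(d))   # true
```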

@kishore-nori (Author) commented May 12, 2023

Just to update: the following variation of the MWE, which loops over the keys instead, is a workaround. (So the problem is the missing rules and methods for Base.ValueIterator, which the examples above go through.)

function mwe(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  s = zero(eltype(x))
  for k in keys(collection)
    s += sum(collection[k])
  end
  s
end

x = rand(3)

Zygote.gradient(mwe, x) # works! 

Edit: I realised this is not general enough. For example, if the Dict's values have different eltypes, seeding s with zero(eltype(x)) is probably not a good idea.

@kishore-nori (Author)

After some trial and error, I have a generic form of the above workaround, for which Zygote.gradient works:

function mwe_generic(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  s = zero(first(values(collection))[1])
  for k in keys(collection)
    @inbounds s += sum(collection[k])
  end
  s
end

x = rand(3)

Zygote.gradient(mwe_generic,x) # works! :)
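To confirm that a workaround like this returns the right values, note that the expected gradient is known in closed form. The function computes f(x) = Σᵢ xᵢ + Σᵢ xᵢ², so ∇f(x) = 1 .+ 2x. The sketch below checks this with central finite differences in plain Julia, without Zygote. The helper fd_gradient and its step size h are illustrative choices, not part of Zygote or of the original report.

```julia
# Forward pass of the workaround above (runs in plain Julia; no Zygote needed).
function mwe_generic(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        s += sum(collection[k])
    end
    s
end

# Hypothetical helper: central finite-difference gradient of a scalar function f at x.
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = rand(3)
g_fd = fd_gradient(mwe_generic, x)
g_exact = 1 .+ 2 .* x              # gradient of sum(x) + sum(x.^2)
isapprox(g_fd, g_exact; rtol = 1e-5)
```

With Zygote loaded, Zygote.gradient(mwe_generic, x)[1] can be compared against g_exact in the same way.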

But it would still be good to have methods and adjoints for Base.ValueIterator so that the original MFE works!

@ToucheSir added the help wanted, up for grabs, and dictionary labels on May 14, 2023
@kishore-nori (Author) commented May 15, 2023

Unfortunately, the above workaround doesn't work for IdDict. Iterating an IdDict seems to hit a ccall (jl_eqtable_nextind) that Zygote can't differentiate through; see the following:

function mfe_IdDict(x::Vector)
  y = x.^2
  collection = IdDict(:a => x, :b => y)
  s = zero(first(values(collection))[1])
  for k in keys(collection)
    @inbounds s += sum(collection[k])
  end
  s
end

julia> Zygote.gradient(mfe_IdDict,x)
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_nextind), UInt64, svec(Any, UInt64), 0, :(:ccall), %2, %5, %4)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./iddict.jl:143 [inlined]
  [3] (::Zygote.Pullback{Tuple{typeof(Base._oidd_nextind), Vector{Any}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.cconvert), Type{UInt64}, Int64}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#325"}}}, Zygote.Pullback{Tuple{typeof(reinterpret), Type{Int64}, UInt64}, Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction, Type{Int64}, UInt64}, Tuple{Core.IntrinsicFunction}}}}, Zygote.Pullback{Tuple{typeof(Base.unsafe_convert), Type{UInt64}, UInt64}, Tuple{}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [4] Pullback
    @ ./iddict.jl:146 [inlined]
  [5] (::Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
  [7] (::Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [8] Pullback
    @ ./abstractdict.jl:64 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(iterate), Base.KeySet{Symbol, IdDict{Symbol, Vector{Float64}}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[6]:7 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

Hi @ToucheSir, are there plans to make Zygote work with IdDict? (Should I open a separate issue? I haven't found any IdDict-related issue here.)

@ToucheSir (Member)

There are no plans to make Zygote work better with any kind of Dict, but only because there is no dev capacity to do so, hence the labels I added above. Dicts are perhaps one of the trickiest types to add functionality or fix bugs for in Zygote, but if any brave soul wants to try, I'd be happy to guide them.
