Error with gradient of function based on Dictionary #1421
Just to update: the following variation of the MWE, which loops over all the keys, is a workaround. (So the problem is with the unavailability of rules and methods for …)

```julia
using Zygote

function mwe(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    s = zero(eltype(x))
    # loop over the keys and index into the Dict instead of iterating its values
    for k in keys(collection)
        s += sum(collection[k])
    end
    s
end

x = rand(3)
Zygote.gradient(mwe, x)  # works!
```

Edit: I realised this is not general enough, for example, if each of the values of …
After some trial and error, I have a generic form of the above workaround, for which …

```julia
using Zygote

function mwe_generic(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    # initialise the accumulator from the first stored value rather than from x
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        @inbounds s += sum(collection[k])
    end
    s
end

x = rand(3)
Zygote.gradient(mwe_generic, x)  # works! :)
```

But it would be good to have …
The above workaround unfortunately doesn't work when the collection is an `IdDict`:

```julia
function mfe_IdDict(x::Vector)
    y = x.^2
    collection = IdDict(:a => x, :b => y)
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        @inbounds s += sum(collection[k])
    end
    s
end
```
```
julia> Zygote.gradient(mfe_IdDict,x)
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_nextind), UInt64, svec(Any, UInt64), 0, :(:ccall), %2, %5, %4)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./iddict.jl:143 [inlined]
  [3] (::Zygote.Pullback{Tuple{typeof(Base._oidd_nextind), Vector{Any}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.cconvert), Type{UInt64}, Int64}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#325"}}}, Zygote.Pullback{Tuple{typeof(reinterpret), Type{Int64}, UInt64}, Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction, Type{Int64}, UInt64}, Tuple{Core.IntrinsicFunction}}}}, Zygote.Pullback{Tuple{typeof(Base.unsafe_convert), Type{UInt64}, UInt64}, Tuple{}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [4] Pullback
    @ ./iddict.jl:146 [inlined]
  [5] (::Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
  [7] (::Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [8] Pullback
    @ ./abstractdict.jl:64 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(iterate), Base.KeySet{Symbol, IdDict{Symbol, Vector{Float64}}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[6]:7 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
```

Hi @ToucheSir, are there plans to make Zygote work with …
There are no plans to make Zygote work better with any kind of Dict, but only because there is no dev capacity to do so. That's why I added the above labels. Dicts are perhaps one of the trickiest types to add new functionality or fix bugs for in Zygote, but if any brave soul wants to try, I'd be happy to guide them.
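Since Dict support is unlikely to improve soon, one possible way around the problem, assuming the keys are fixed when the function is written, is to avoid the `Dict` altogether and store the values in a `NamedTuple`. The sketch below is not from the thread, and the name `mwe_namedtuple` is hypothetical. It computes the same quantity as `mwe`, but field access on a `NamedTuple` is ordinary `getproperty`, which Zygote differentiates without any Dict-specific rules.

```julia
using Zygote

# Hypothetical variant of `mwe`: the Dict is replaced by a NamedTuple
# whose keys (:a, :b) are part of its type.
function mwe_namedtuple(x::Vector)
    y = x .^ 2
    collection = (a = x, b = y)
    return sum(collection.a) + sum(collection.b)
end

x = rand(3)
g, = Zygote.gradient(mwe_namedtuple, x)

# analytic check: d/dx [sum(x) + sum(x.^2)] = 1 .+ 2 .* x
@assert g ≈ 1 .+ 2 .* x
```

The trade-off is that the keys become part of the type, so this does not replace a `Dict` whose keys are only known at run time.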
Hi,

I encountered the following errors when working with functions based on dictionaries. Below are the minimal failing examples (MFEs) and my naive attempts. (They seem to require some methods and `adjoint`s for the `Base.ValueIterator` type.)

The above results in the following error:
Since the above asks for `size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})`, and realising that the method `length(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})` exists, I tried adding the following method. I don't know if it is the right way to go, but it seems to make the forward pass error-free. Now, however, `Zygote.gradient` requests an `adjoint`; see the following updated error:

Independently of the above, the following alternative MFE throws the same `Need an adjoint` error:

I would be happy to know whether this is fixable by writing the `adjoint` that the error requests, or whether there is a workaround for this issue. Thank you!