Open
Labels: piracy (A bug caused by a third-party committing piracy)
Description
In a project of mine, I want to take derivatives of a Neural SDE solution (computed by the custom wrapper `msolve`) with respect to the Lux NN parameters:
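```julia
using Zygote

# `msolve` is my custom wrapper that solves the Neural SDE for parameters `ps`
function logvar(prob; ps=prob.p, n=100)  # calling this method works
    sum(msolve(prob, ps=ps) for i in 1:n)
end

Zygote.gradient(ps -> logvar(prob, ps=ps, n=n), prob.p)[1]  # this doesn't
```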
The gradient call fails with a

```
MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Stacktrace: [...]
 [3] accum(x::NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, y::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:27
```
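To illustrate why accumulation ends up calling `+` on these two types, here is a minimal sketch in plain Julia (no Zygote). `acc` is a hypothetical, simplified stand-in, not Zygote's actual `accum`; it assumes gradient accumulation roughly works like this: `nothing` means "no gradient", NamedTuples with the same keys are combined field by field, and anything else falls back to `+`. Two contributions whose `data` fields are an empty `Tuple` and an empty `NamedTuple` then reproduce the same `MethodError`:

```julia
# Hypothetical, simplified stand-in for gradient accumulation (NOT Zygote's code):
# `nothing` means "no gradient", NamedTuples with the same keys are combined
# field by field, and everything else falls back to `+`.
acc(x, y) = x === nothing ? y : y === nothing ? x : x + y
acc(x::NamedTuple{K}, y::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(acc, values(x), values(y)))

# Matching shapes accumulate fine.
println(acc((a = 1.0,), (a = 2.0,)))

# Two contributions shaped like the ones in the stack trace above.
g1 = (data = (), itr = nothing)            # `data` gradient as an empty Tuple
g2 = (data = NamedTuple(), itr = nothing)  # `data` gradient as an empty NamedTuple

err = try
    acc(g1, g2)  # recurses into `data` and calls `() + NamedTuple()`
    nothing
catch e
    e
end
println(err isa MethodError)
```

The outer NamedTuples match on keys, so the field-wise method is used; the mismatch only surfaces one level down, where neither branch knows how to add a `Tuple{}` to a `NamedTuple{(), Tuple{}}`.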
After following the suggestion of @ToucheSir in #1290 and replacing the generator with `sum(_ -> msolve(prob, ps=ps), 1:n)`, the error changes to

```
MethodError: no method matching +(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
```
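I hotfixed this with

```julia
import Base.+

# Hotfix: add the two empty `(data, itr)` gradients that Zygote fails to accumulate.
# This defines a method of Base's `+` on types my code does not own (type piracy).
+(::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}},
  ::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}) = (data=(;), itr=nothing)
```

and the code now runs through.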
Searching for occurrences of `(:data, :itr)`, the only one I could find is Line 155 in de078c8:

```julia
dps = (data = Base.setindex(data, Δ, k), itr = nothing)
```

and the respective function below it.
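For context, a small plain-Julia sketch (no Zygote): `(:data, :itr)` are the field names of the `Pairs` wrapper that Julia uses for keyword arguments collected with `kw...`, and `logvar` is called with keyword arguments (`ps=ps, n=n`). `collect_kwargs` below is a hypothetical helper. The sketch shows that structure and how `Base.setindex` builds a `data`-shaped value for a single key; whether this is where the mismatched gradient actually comes from is an assumption, not something verified here.

```julia
# Hypothetical helper: returns the keyword arguments exactly as Julia packages them.
collect_kwargs(; kw...) = kw

kw = collect_kwargs(ps = [1.0, 2.0], n = 100)
println(fieldnames(typeof(kw)))  # (:data, :itr)
println(kw.data)                 # (ps = [1.0, 2.0], n = 100)

# `Base.setindex` on a NamedTuple is non-mutating: it returns a copy with one
# field replaced, so a gradient shaped like `data` can carry a value for one key.
zero_grad = map(_ -> nothing, kw.data)                 # (ps = nothing, n = nothing)
grad_data = Base.setindex(zero_grad, [0.5, 0.5], :ps)
println(grad_data)               # (ps = [0.5, 0.5], n = nothing)

# With no keyword arguments, `data` is an empty NamedTuple, not an empty Tuple.
empty_kw = collect_kwargs()
println(empty_kw.data === NamedTuple())
```

Note that the empty case yields `NamedTuple()`, which matches the `NamedTuple{(), Tuple{}}` side of the first error, while the other side of that error carries an empty `Tuple{}`.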
I have no clue how all of this fits together, but many thanks to @mcabbott and @ToucheSir for helping me find the fix.
Feel free to correct the issue title, and let me know if I can be of any further help fixing this (although I am quite out of my depth regarding Zygote internals).