Closed
Description
When I use tuples in the function `g` (instead of vectors, as in the function `f`) in the example below, computing the gradient with Zygote fails:
using Zygote
# Two random length-3 inputs. Only `x` is differentiated; `y` is captured by the
# closures below and held fixed.
x = rand(3)
y = rand(3)
# f: builds a Vector of 2-element Vectors `[x[i], y[i]]`, then sums every entry.
# Differentiating w.r.t. `x` works (prints `([1.0, 1.0, 1.0],)`).
f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
Zygote.gradient(x->f(x,y), x) |> display
# g: the same as f, except each pair is a Tuple `(x[i], y[i])` instead of a Vector.
# Differentiating it throws: ChainRules' `rrule` for `sum(f, xs)` calls
# `ProjectTo` on the `Vector{Tuple{Float64, Float64}}`, which maps `ProjectTo` over
# the elements, and there is no `ProjectTo(::Tuple{Float64, Float64})` method.
# NOTE(review): keep the comprehension / `sum(sum, ...)` form exactly as is —
# rewriting it (e.g. as a generator) may bypass that rrule and hide the bug.
g(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)])
Zygote.gradient(x->g(x,y), x) |> display
Running the script prints the gradient of `f` (first line below) and then fails on `g` with the following error:
([1.0, 1.0, 1.0],)
ERROR: LoadError: MethodError: no method matching ChainRulesCore.ProjectTo(::Tuple{Float64, Float64})
Closest candidates are:
ChainRulesCore.ProjectTo(::LinearAlgebra.UnitLowerTriangular) at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:336
ChainRulesCore.ProjectTo(::LinearAlgebra.UpperTriangular) at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:336
ChainRulesCore.ProjectTo(::LinearAlgebra.SymTridiagonal{T, V} where V<:AbstractVector{T}) where T<:Number at /home/niklas/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:373
...
Stacktrace:
[1] iterate
@ ./generator.jl:47 [inlined]
[2] _collect
@ ./array.jl:691 [inlined]
[3] collect_similar
@ ./array.jl:606 [inlined]
[4] map
@ ./abstractarray.jl:2294 [inlined]
[5] ChainRulesCore.ProjectTo(xs::Vector{Tuple{Float64, Float64}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/Voykb/src/projection.jl:192
[6] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context}, ::typeof(sum), f::Function, xs::Vector{Tuple{Float64, Float64}}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/5iZFH/src/rulesets/Base/mapreduce.jl:74
[7] rrule
@ ~/.julia/packages/ChainRules/5iZFH/src/rulesets/Base/mapreduce.jl:69 [inlined]
[8] chain_rrule
@ ~/.julia/packages/Zygote/l3aNG/src/compiler/chainrules.jl:152 [inlined]
[9] macro expansion
@ ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(sum), ::typeof(sum), ::Vector{Tuple{Float64, Float64}})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:9
[11] _pullback
@ ~/_asdf.jl:10 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(g), ::Vector{Float64}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
[13] _pullback
@ ~/_asdf.jl:11 [inlined]
[14] _pullback(ctx::Zygote.Context, f::var"#7#8", args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
[15] _pullback(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:34
[16] pullback(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:40
[17] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:75
[18] top-level scope
@ ~/_asdf.jl:11
[19] include(fname::String)
@ Base.MainInclude ./client.jl:444
[20] top-level scope
@ REPL[1]:1
in expression starting at /home/niklas/_asdf.jl:11