Closed
Description
Taking the gradient of a function of a Jacobian computed with Zygote.forward_jacobian fails. Here's an MWE:
using Zygote
f(x) = [cos(x[1]) * sin(x[2]), sin(x[1]) * cos(x[2])]
jac(x) = last(Zygote.forward_jacobian(f, x))
g(x) = sum(jac(x))
x = [π/4, π/3]
Zygote.gradient(g, x)
This fails with
julia> Zygote.gradient(g, x)
ERROR: Compiling Tuple{typeof(Zygote.forward_jacobian),typeof(f),Array{Float64,1},Val{2}}: DimensionMismatch("dimensions must match")
Stacktrace:
[1] promote_shape at ./indices.jl:154 [inlined]
[2] _promote_shape at ./iterators.jl:317 [inlined]
[3] axes at ./iterators.jl:316 [inlined]
[4] map at ./tuple.jl:166 [inlined]
[5] axes at ./iterators.jl:316 [inlined]
[6] _array_for(::Type{Array{Any,1}}, ::Base.Iterators.Zip{Tuple{Array{Any,1},Base.Iterators.Zip{Tuple{Array{Any,1},Array{Any,1}}}}}, ::Base.HasShape{1}) at ./array.jl:598
[7] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{Any,1},Base.Iterators.Zip{Tuple{Array{Any,1},Array{Any,1}}}}},getfield(IRTools, Symbol("##126#132"))}) at ./array.jl:611
[8] prune!(::IRTools.IR) at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/passes/passes.jl:121
[9] |> at ./operators.jl:813 [inlined]
[10] #IR#23(::Bool, ::Bool, ::Type, ::IRTools.Meta) at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/ir/wrap.jl:195
[11] Type at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/ir/wrap.jl:190 [inlined]
[12] _lookup_grad(::Type) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/emit.jl:117
[13] #s2387#1312 at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:20 [inlined]
[14] #s2387#1312(::Any, ::Any, ::Any) at ./none:0
[15] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:522
[16] forward_jacobian at /Users/saxen/.julia/packages/Zygote/bl1uE/src/lib/forward.jl:38 [inlined]
[17] (::typeof(∂(Zygote.forward_jacobian)))(::Tuple{Nothing,FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:0
[18] jac at ./REPL[3]:1 [inlined]
[19] (::typeof(∂(jac)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:0
[20] g at ./REPL[4]:1 [inlined]
[21] (::getfield(Zygote, Symbol("##32#33")){typeof(∂(g))})(::Float64) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface.jl:38
[22] gradient(::Function, ::Array{Float64,1}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface.jl:47
[23] top-level scope at none:0
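It's not obvious to me from this error message what the problem is. As far as I can tell from the trace, the DimensionMismatch is thrown from IRTools' prune! pass (frame [8]) while Zygote is compiling the adjoint of the three-argument method forward_jacobian(f, x, Val{2}), which is called at src/lib/forward.jl:38 (frame [16]). So the failure seems to happen while the gradient code is being generated, not while f itself is evaluated. I'd appreciate any insight.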
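For reference, this is what Zygote.gradient(g, x) should return. Summing the Jacobian of f gives g(x) = 2cos(x1 + x2), so the expected gradient is [-2sin(x1 + x2), -2sin(x1 + x2)], about -1.932 in each component at x = [π/4, π/3]. The block below is a minimal sketch in plain Julia with no Zygote and no packages. It checks those values with central finite differences. The helper names jac_exact, g_exact, grad_exact and fd_gradient are hypothetical and exist only for this illustration. This is a cross-check for whoever looks into the bug, not a workaround.

```julia
# Same f as in the MWE above.
f(x) = [cos(x[1]) * sin(x[2]), sin(x[1]) * cos(x[2])]

# Hand-derived Jacobian of f (hypothetical helper): J[i, j] = ∂f_i/∂x_j
function jac_exact(x)
    a, b = x[1], x[2]
    d11 = -sin(a) * sin(b)   # ∂f1/∂x1
    d12 =  cos(a) * cos(b)   # ∂f1/∂x2
    d21 =  cos(a) * cos(b)   # ∂f2/∂x1
    d22 = -sin(a) * sin(b)   # ∂f2/∂x2
    return [d11 d12; d21 d22]
end

# g from the MWE, written with the exact Jacobian; it simplifies to 2cos(x1 + x2)
g_exact(x) = sum(jac_exact(x))

# Gradient of 2cos(x1 + x2): both components equal -2sin(x1 + x2)
grad_exact(x) = fill(-2 * sin(x[1] + x[2]), 2)

# Central finite-difference gradient of a scalar function h (hypothetical helper)
function fd_gradient(h, x; δ=1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = δ
        grad[i] = (h(x .+ e) - h(x .- e)) / (2δ)
    end
    return grad
end

x = [π/4, π/3]
println(grad_exact(x))            # analytic gradient
println(fd_gradient(g_exact, x))  # finite-difference estimate, should agree to about 1e-6
```

Finite differences are used here only because they need nothing beyond Base; they are slower and less accurate than AD, but are good enough to pin down the value Zygote should produce.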
Version info:
julia> versioninfo()
Julia Version 1.1.0
Commit 80516ca202 (2019-01-21 21:24 UTC)
Platform Info:
OS: macOS (x86_64-apple-darwin14.5.0)
CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-6.0.1 (ORCJIT, haswell)
(v1.1) pkg> st --manifest Zygote
Status `~/.julia/environments/v1.1/Manifest.toml`
[b552c78f] DiffRules v0.0.10
[7a1cc6ca] FFTW v0.3.0
[1a297f60] FillArrays v0.6.4
[f6369f11] ForwardDiff v0.10.3
[7869d1d1] IRTools v0.2.2 #master (https://github.com/MikeInnes/IRTools.jl.git)
[1914dd2f] MacroTools v0.5.1
[872c559c] NNlib v0.6.0
[77ba4419] NaNMath v0.3.2
[ae029012] Requires v0.5.2
[276daf66] SpecialFunctions v0.7.2
[e88e6eb3] Zygote v0.3.2 #master (https://github.com/FluxML/Zygote.jl)
[700de1a5] ZygoteRules v0.1.0
[b77e0a4c] InteractiveUtils
[37e2e46d] LinearAlgebra
[9a3f8284] Random
[10745b16] Statistics