
Cannot take gradient of function of jacobian #305

Closed
@sethaxen

Description

Taking the gradient of a function of a Jacobian fails. Here's an MWE:

using Zygote
f(x) = [cos(x[1]) * sin(x[2]), sin(x[1]) * cos(x[2])]
jac(x) = last(Zygote.forward_jacobian(f, x))
g(x) = sum(jac(x))
x = [π/4, π/3]
Zygote.gradient(g, x)
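For reference, the expected result can be worked out by hand, writing $x_1, x_2$ for `x[1]`, `x[2]`. Since summing all entries gives the same value for $J$ and its transpose, $g$ does not depend on which layout `forward_jacobian` uses for the Jacobian:

$$
J(x) = \begin{pmatrix} -\sin x_1 \sin x_2 & \cos x_1 \cos x_2 \\ \cos x_1 \cos x_2 & -\sin x_1 \sin x_2 \end{pmatrix},
\qquad
g(x) = \sum_{i,j} J_{ij}(x) = 2(\cos x_1 \cos x_2 - \sin x_1 \sin x_2) = 2\cos(x_1 + x_2),
$$

$$
\nabla g(x) = -2\sin(x_1 + x_2)\begin{pmatrix}1\\1\end{pmatrix},
\qquad
\nabla g\left(\tfrac{\pi}{4}, \tfrac{\pi}{3}\right) = -2\sin\tfrac{7\pi}{12}\begin{pmatrix}1\\1\end{pmatrix} = -\tfrac{\sqrt{6}+\sqrt{2}}{2}\begin{pmatrix}1\\1\end{pmatrix} \approx \begin{pmatrix}-1.93185\\-1.93185\end{pmatrix}.
$$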

This fails with

julia> Zygote.gradient(g, x)
ERROR: Compiling Tuple{typeof(Zygote.forward_jacobian),typeof(f),Array{Float64,1},Val{2}}: DimensionMismatch("dimensions must match")
Stacktrace:
 [1] promote_shape at ./indices.jl:154 [inlined]
 [2] _promote_shape at ./iterators.jl:317 [inlined]
 [3] axes at ./iterators.jl:316 [inlined]
 [4] map at ./tuple.jl:166 [inlined]
 [5] axes at ./iterators.jl:316 [inlined]
 [6] _array_for(::Type{Array{Any,1}}, ::Base.Iterators.Zip{Tuple{Array{Any,1},Base.Iterators.Zip{Tuple{Array{Any,1},Array{Any,1}}}}}, ::Base.HasShape{1}) at ./array.jl:598
 [7] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{Any,1},Base.Iterators.Zip{Tuple{Array{Any,1},Array{Any,1}}}}},getfield(IRTools, Symbol("##126#132"))}) at ./array.jl:611
 [8] prune!(::IRTools.IR) at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/passes/passes.jl:121
 [9] |> at ./operators.jl:813 [inlined]
 [10] #IR#23(::Bool, ::Bool, ::Type, ::IRTools.Meta) at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/ir/wrap.jl:195
 [11] Type at /Users/saxen/.julia/packages/IRTools/y0Ot8/src/ir/wrap.jl:190 [inlined]
 [12] _lookup_grad(::Type) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/emit.jl:117
 [13] #s2387#1312 at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:20 [inlined]
 [14] #s2387#1312(::Any, ::Any, ::Any) at ./none:0
 [15] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:522
 [16] forward_jacobian at /Users/saxen/.julia/packages/Zygote/bl1uE/src/lib/forward.jl:38 [inlined]
 [17] (::typeof((Zygote.forward_jacobian)))(::Tuple{Nothing,FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:0
 [18] jac at ./REPL[3]:1 [inlined]
 [19] (::typeof((jac)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface2.jl:0
 [20] g at ./REPL[4]:1 [inlined]
 [21] (::getfield(Zygote, Symbol("##32#33")){typeof((g))})(::Float64) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface.jl:38
 [22] gradient(::Function, ::Array{Float64,1}) at /Users/saxen/.julia/packages/Zygote/bl1uE/src/compiler/interface.jl:47
 [23] top-level scope at none:0

It's not obvious to me from this error message what the problem is. I'd appreciate any insight.
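A reference value for the gradient can be obtained without any nested AD, which helps confirm what a working `Zygote.gradient(g, x)` should return. The sketch below uses only Base Julia: it differentiates a hand-written Jacobian of `f` with central finite differences. The helpers `jac_manual`, `g_manual` and `fd_gradient` are hypothetical names for illustration, not part of Zygote's API.

```julia
# Reference gradient for g(x) = sum(jacobian of f at x), computed without nested AD.
# jac_manual, g_manual and fd_gradient are illustrative helpers, not Zygote API.

# Hand-written Jacobian of f(x) = [cos(x[1])*sin(x[2]), sin(x[1])*cos(x[2])];
# row i holds the partial derivatives of output i.
function jac_manual(x)
    s1, c1 = sin(x[1]), cos(x[1])
    s2, c2 = sin(x[2]), cos(x[2])
    return [(-s1 * s2) (c1 * c2);
            (c1 * c2) (-s1 * s2)]
end

g_manual(x) = sum(jac_manual(x))

# Central finite-difference gradient of a scalar function `fun` at `x`.
function fd_gradient(fun, x; δ = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = δ
        grad[i] = (fun(x + e) - fun(x - e)) / (2δ)
    end
    return grad
end

x = [π/4, π/3]
grad_fd = fd_gradient(g_manual, x)

# sum(J) simplifies to 2cos(x[1] + x[2]), so both entries should be -2sin(x[1] + x[2]).
grad_exact = fill(-2 * sin(x[1] + x[2]), 2)
@assert isapprox(grad_fd, grad_exact; atol = 1e-6)
println(grad_fd)  # both entries ≈ -1.93185
```

A working nested-AD call on the MWE's `g` should agree with `grad_fd` to within finite-difference error.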

Version info:

julia> versioninfo()
Julia Version 1.1.0
Commit 80516ca202 (2019-01-21 21:24 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin14.5.0)
  CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, haswell)

(v1.1) pkg> st --manifest Zygote
    Status `~/.julia/environments/v1.1/Manifest.toml`
  [b552c78f] DiffRules v0.0.10
  [7a1cc6ca] FFTW v0.3.0
  [1a297f60] FillArrays v0.6.4
  [f6369f11] ForwardDiff v0.10.3
  [7869d1d1] IRTools v0.2.2 #master (https://github.com/MikeInnes/IRTools.jl.git)
  [1914dd2f] MacroTools v0.5.1
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [ae029012] Requires v0.5.2
  [276daf66] SpecialFunctions v0.7.2
  [e88e6eb3] Zygote v0.3.2 #master (https://github.com/FluxML/Zygote.jl)
  [700de1a5] ZygoteRules v0.1.0
  [b77e0a4c] InteractiveUtils 
  [37e2e46d] LinearAlgebra 
  [9a3f8284] Random 
  [10745b16] Statistics 

Labels: second order (zygote over zygote, or otherwise)
