Zygote and StructArrays do not play nicely #257

Open
ptiede opened this issue Dec 8, 2022 · 0 comments

ptiede commented Dec 8, 2022

Hi All,

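I noticed during development that Zygote and StructArrays do not work together. It seems that if a function accesses a property of the StructArray (e.g. `x.U`), Zygote/ChainRules does not maintain the StructArray type for the gradient, and this causes an error during gradient accumulation: `accum` receives a NamedTuple with a `components` field from one branch and a Vector of NamedTuples from the other, and no `+` method exists for that pair. An MWE is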

using Zygote
using StructArrays

f(p) = p.U^2 + p.V^2
l1(x) = sum(f, x) + sum(x.U)
l2(x) = sum(f.(x)  + x.U)
l3(x) = sum(f.(x) .+ x.U)


x = StructArray{NamedTuple{(:U,:V)}}((U=rand(10), V=rand(10)))

Zygote.gradient(l1, x) 
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
  [2] collect_similar
    @ ./array.jl:716 [inlined]
  [3] map
    @ ./abstractarray.jl:2933 [inlined]
  [4] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:122 [inlined]
  [5] map
    @ ./tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:106 [inlined]
  [7] ZBack
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:206 [inlined]
  [8] Pullback
    @ ~/struct_array_issue.jl:5 [inlined]
  [9] (::typeof((l0)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof((l0))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [11] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [12] top-level scope
    @ REPL[50]:1

#####################################################################
Zygote.gradient(l2, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
 [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
 [2] Pullback
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
 [3] Pullback
   @ ~/struct_array_issue.jl:6 [inlined]
 [4] (::typeof((l1)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#60#61"{typeof((l1))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [7] top-level scope
   @ REPL[51]:1


##########################################################################
Zygote.gradient(l3, x)
ERROR: MethodError: no method matching +(::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, ::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
  +(::Array, ::Array...) at arraymath.jl:12
  ...
Stacktrace:
 [1] accum(x::NamedTuple{(:components,), Tuple{NamedTuple{(:U, :V), Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, Nothing}}}}, y::Vector{NamedTuple{(:U, :V), Tuple{Float64, Float64}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/lib/lib.jl:17
 [2] Pullback
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:43 [inlined]
 [3] Pullback
   @ ~/struct_array_issue.jl:7 [inlined]
 [4] (::typeof((l2)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#60#61"{typeof((l2))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:45
 [6] gradient(f::Function, args::StructVector{NamedTuple{(:U, :V)}, NamedTuple{(:U, :V), Tuple{Vector{Float64}, Vector{Float64}}}, Int64})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:97
 [7] top-level scope
   @ REPL[52]:1

On the other hand, a loss that never accesses a property of `x` directly,

l0(x) = sum(f, x)

Zygote.gradient(l0, x)

works fine and returns a StructArray.

I have been experimenting with ChainRulesCore and ProjectTo to see if I could get this to work, but I am not sure of the best way to store everything internally.
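Until this is fixed, one possible workaround is to differentiate with respect to the plain component vectors rather than the StructArray itself, so that every gradient contribution is an ordinary array. The sketch below is only an illustration under that assumption: `fc` and `loss` are hypothetical names (not from the MWE), and `U` and `V` stand in for `x.U` and `x.V`.

```julia
using Zygote

# Hypothetical component-wise version of f from the MWE.
fc(u, v) = u^2 + v^2

# Same quantity as l1/l3 above, but written in terms of the component vectors,
# so Zygote only ever accumulates plain arrays.
loss(U, V) = sum(fc.(U, V)) + sum(U)

U = rand(10)   # stands in for x.U
V = rand(10)   # stands in for x.V

# One gradient per argument; analytically these are 2U .+ 1 and 2V.
gU, gV = Zygote.gradient(loss, U, V)
```

With a StructArray `x` as in the MWE, the call would be `Zygote.gradient(loss, x.U, x.V)`. This sidesteps the accumulation error but gives up the StructArray-typed gradient, so it is a stopgap rather than a fix.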

Working environment

julia> Pkg.status()
Status /tmp/jl_sssQXD/Project.toml
[09ab397b] StructArrays v0.6.13 https://github.com/JuliaArrays/StructArrays.jl.git#master
[e88e6eb3] Zygote v0.6.51

julia> versioninfo()
Julia Version 1.8.3
Commit 0434deb161e (2022-11-14 20:14 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 7950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, znver3)
  Threads: 1 on 32 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 1
