Skip to content

Commit

Permalink
array constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
cossio committed Jun 8, 2020
1 parent d3f9371 commit ed08e72
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 30 deletions.
1 change: 1 addition & 0 deletions src/ZygoteStructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ZygoteStructArrays
using Zygote, StructArrays, LinearAlgebra
using Zygote: @adjoint, Numeric, literal_getproperty, accum

include("others.jl")
include("adjoints.jl")

end
57 changes: 28 additions & 29 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
@adjoint function (::Type{T})(t::Tuple) where {T<:StructArray}
result = T(t)
@adjoint function (::Type{SA})(t::Tuple) where {SA<:StructArray}
sa = SA(t)
back::NamedTuple) = (values(Δ),)
function back::AbstractArray{<:NamedTuple})
nt = (; (p => [getproperty(dx, p) for dx in Δ] for p in propertynames(result))...)
nt = (; (p => [getproperty(dx, p) for dx in Δ] for p in propertynames(sa))...)
return back(nt)
end
return result, back
return sa, back
end

@adjoint function (::Type{T})(t::NamedTuple) where {T<:StructArray}
result = T(t)
back::NamedTuple) = (NamedTuple{propertynames(result)}(Δ),)
@adjoint function (::Type{SA})(t::NamedTuple) where {SA<:StructArray}
sa = SA(t)
back::NamedTuple) = (NamedTuple{propertynames(sa)}(Δ),)
function back::AbstractArray{<:NamedTuple})
nt = (; (p => [dx[p] for dx in Δ] for p in propertynames(result))...)
return back(nt)
back((; (p => [dx[p] for dx in Δ] for p in propertynames(sa))...))
end
return result, back
return sa, back
end

@adjoint function (::Type{SA})(a::AbstractArray{T}) where {T,SA<:StructArray}
sa = SA(a)
back::NamedTuple) = ([(; (p => Δ[p][i] for p in propertynames(Δ))...) for i in eachindex(a)],)
return sa, back
end

@adjoint function literal_getproperty(sa::StructArray, ::Val{key}) where {key}
Expand All @@ -28,30 +33,24 @@ end
return result, back
end

@adjoint function (::Type{T})(t::Tuple) where {K,T<:NamedTuple{K}}
result = T(t)
back::NamedTuple) = (values(T(Δ)),)
return result, back
end

@adjoint Base.getindex(x::StructArray, i...) = x[i...], Δ -> ∇getindex(x, i, Δ)
@adjoint Base.view(x::StructArray, i...) = view(x, i...), Δ -> ∇getindex(x, i, Δ)
function ∇getindex(x::StructArray, i, Δ::NamedTuple)
dx = (; (k => ∇getindex(v, i, Δ[k]) for (k,v) in pairs(fieldarrays(x)))...)
@adjoint Base.getindex(sa::StructArray, i...) = sa[i...], Δ -> ∇getindex(sa,i,Δ)
@adjoint Base.view(sa::StructArray, i...) = view(sa, i...), Δ -> ∇getindex(sa,i,Δ)
function ∇getindex(sa::StructArray, i, Δ::NamedTuple)
dsa = (; (k => ∇getindex(v,i,Δ[k]) for (k,v) in pairs(fieldarrays(sa)))...)
di = map(_ -> nothing, i)
return (dx, map(_ -> nothing, i)...)
return (dsa, map(_ -> nothing, i)...)
end
# based on
# https://github.com/FluxML/Zygote.jl/blob/64c02dccc698292c548c334a15ce2100a11403e2/src/lib/array.jl#L41
∇getindex(x::AbstractArray, i, Δ::Nothing) = nothing
function ∇getindex(x::AbstractArray, i, Δ)
∇getindex(a::AbstractArray, i, Δ::Nothing) = nothing
function ∇getindex(a::AbstractArray, i, Δ)
if i isa NTuple{<:Any, Integer}
dx = Zygote._zero(x, typeof(Δ))
dx[i...] = Δ
da = Zygote._zero(a, typeof(Δ))
da[i...] = Δ
else
dx = Zygote._zero(x, eltype(Δ))
dxv = view(dx, i...)
dxv .= Zygote.accum.(dxv, Zygote._droplike(Δ, dxv))
da = Zygote._zero(a, eltype(Δ))
dav = view(da, i...)
dav .= Zygote.accum.(dav, Zygote._droplike(Δ, dav))
end
return dx
return da
end
14 changes: 14 additions & 0 deletions src/others.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#= These adjoints should probably be in Zygote =#

@adjoint function (::Type{NT})(t::Tuple) where {K,NT<:NamedTuple{K}}
nt = NT(t)
back::NamedTuple) = (values(NT(Δ)),)
return nt, back
end

# # https://github.com/FluxML/Zygote.jl/issues/680
# @adjoint function (T::Type{<:Complex})(re, im)
# back(Δ::Complex) = (nothing, real(Δ), imag(Δ))
# back(Δ::NamedTuple) = (nothing, Δ.re, Δ.im)
# T(re, im), back
# end
16 changes: 15 additions & 1 deletion test/adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using Test, Random, Zygote, StructArrays, ZygoteStructArrays

struct Point
x::Float64; y::Float64
end

@test gradient(randn(2), randn(2)) do X,Y
S = StructArray{Complex}((X,Y))
sum(S.re) + 2sum(S.im)
Expand Down Expand Up @@ -43,4 +47,14 @@ end == ([1.0, 0.0], [1.0, 0.0])
@test gradient(randn(2), randn(2)) do X, Y
S = StructArray{Complex}((re = X, im = Y))
S[1].re + S[2].re
end == ([1.0, 1.0], nothing)
end == ([1.0, 1.0], nothing)

@test gradient(randn(2), randn(2)) do X, Y
S = StructArray(Point.(X, Y))
sum(S.x) + sum(S.y)
end == ([1.0, 1.0], [1.0, 1.0])

# @test gradient(randn(2), randn(2)) do X, Y
# S = StructArray{Complex}((re = X, im = Y))
# sum(abs.(S))
# end

0 comments on commit ed08e72

Please sign in to comment.