From d3f9371cf45a8ba6189d3d8b98e9b35ed2f60e4c Mon Sep 17 00:00:00 2001
From: cossio
Date: Mon, 8 Jun 2020 14:03:57 +0200
Subject: [PATCH] getindex

---
 Project.toml              |  1 +
 src/ZygoteStructArrays.jl |  4 ++--
 src/adjoints.jl           | 37 +++++++++++++++++++++++++++++++++++--
 test/adjoints.jl          | 15 +++++++++++++++
 4 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/Project.toml b/Project.toml
index 668bc22..5dd9491 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ authors = ["Jorge Fernandez-de-Cossio-Diaz "]
 version = "0.1.0"
 
 [deps]
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
diff --git a/src/ZygoteStructArrays.jl b/src/ZygoteStructArrays.jl
index eb40db6..7dd8c08 100644
--- a/src/ZygoteStructArrays.jl
+++ b/src/ZygoteStructArrays.jl
@@ -1,7 +1,7 @@
 module ZygoteStructArrays
 
-using Zygote, StructArrays
-using Zygote: @adjoint, literal_getproperty
+using Zygote, StructArrays, LinearAlgebra
+using Zygote: @adjoint, Numeric, literal_getproperty, accum
 
 include("adjoints.jl")
 
diff --git a/src/adjoints.jl b/src/adjoints.jl
index 48258c5..7fa278e 100644
--- a/src/adjoints.jl
+++ b/src/adjoints.jl
@@ -10,7 +10,7 @@ end
 
 @adjoint function (::Type{T})(t::NamedTuple) where {T<:StructArray}
     result = T(t)
-    back(Δ::NamedTuple) = (fieldarrays(T(Δ)),)
+    back(Δ::NamedTuple) = (NamedTuple{propertynames(result)}(Δ),)
     function back(Δ::AbstractArray{<:NamedTuple})
         nt = (; (p => [dx[p] for dx in Δ] for p in propertynames(result))...)
         return back(nt)
@@ -22,7 +22,7 @@ end
     key::Symbol
     result = getproperty(sa, key)
     function back(Δ::AbstractArray)
-        nt = (; (p => zero(getproperty(sa, p)) for p in propertynames(sa))...)
+        nt = (; (k => zero(v) for (k,v) in pairs(fieldarrays(sa)))...)
         return (Base.setindex(nt, Δ, key), nothing)
     end
     return result, back
@@ -33,3 +33,36 @@ end
     back(Δ::NamedTuple) = (values(T(Δ)),)
     return result, back
 end
+
+@adjoint Base.getindex(x::StructArray, i...) = x[i...], Δ -> ∇getindex(x, i, Δ)
+@adjoint Base.view(x::StructArray, i...) = view(x, i...), Δ -> ∇getindex(x, i, Δ)
+
+# Pullback shared by the `getindex` and `view` adjoints of a StructArray.
+# `Δ` is a NamedTuple indexed by field name; each entry is pulled back through
+# the matching field array. Returns the gradient for `x` followed by one
+# `nothing` per index in `i`.
+function ∇getindex(x::StructArray, i, Δ::NamedTuple)
+    dx = (; (k => ∇getindex(v, i, Δ[k]) for (k,v) in pairs(fieldarrays(x)))...)
+    return (dx, map(_ -> nothing, i)...)
+end
+
+# based on
+# https://github.com/FluxML/Zygote.jl/blob/64c02dccc698292c548c334a15ce2100a11403e2/src/lib/array.jl#L41
+# A `nothing` cotangent produces a `nothing` gradient.
+∇getindex(x::AbstractArray, i, Δ::Nothing) = nothing
+
+# Gradient of indexing a plain array: write `Δ` at `i` into an array obtained
+# from `Zygote._zero` (presumably zero-filled and shaped like `x`; see link above).
+function ∇getindex(x::AbstractArray, i, Δ)
+    if i isa NTuple{<:Any, Integer}
+        # all indices are integers: a single element receives `Δ`
+        dx = Zygote._zero(x, typeof(Δ))
+        dx[i...] = Δ
+    else
+        # otherwise combine `Δ` into a view of the selected entries via `accum`
+        dx = Zygote._zero(x, eltype(Δ))
+        dxv = view(dx, i...)
+        dxv .= Zygote.accum.(dxv, Zygote._droplike(Δ, dxv))
+    end
+    return dx
+end
diff --git a/test/adjoints.jl b/test/adjoints.jl
index 32baa77..cb1e8b6 100644
--- a/test/adjoints.jl
+++ b/test/adjoints.jl
@@ -29,3 +29,18 @@ end == ([1.0, 1.0], [2.0, 2.0])
     S = StructArray{Complex}((im = Y, re = X))
     sum(S).re + 2sum(S).im
 end == ([1.0, 1.0], [2.0, 2.0])
+
+@test gradient(randn(2), randn(2)) do X, Y
+    S = StructArray{Complex}((re = X, im = Y))
+    S[1].re
+end == ([1.0, 0.0], nothing)
+
+@test gradient(randn(2), randn(2)) do X, Y
+    S = StructArray{Complex}((re = X, im = Y))
+    S[1].re + S[1].im
+end == ([1.0, 0.0], [1.0, 0.0])
+
+@test gradient(randn(2), randn(2)) do X, Y
+    S = StructArray{Complex}((re = X, im = Y))
+    S[1].re + S[2].re
+end == ([1.0, 1.0], nothing)
\ No newline at end of file