Skip to content

Commit

Permalink
getindex
Browse files Browse the repository at this point in the history
  • Loading branch information
cossio committed Jun 8, 2020
1 parent 0c1c82c commit d3f9371
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Jorge Fernandez-de-Cossio-Diaz <j.cossio.diaz@gmail.com>"]
version = "0.1.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
4 changes: 2 additions & 2 deletions src/ZygoteStructArrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ZygoteStructArrays

using Zygote, StructArrays
using Zygote: @adjoint, literal_getproperty
using Zygote, StructArrays, LinearAlgebra
using Zygote: @adjoint, Numeric, literal_getproperty, accum

include("adjoints.jl")

Expand Down
26 changes: 24 additions & 2 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end

@adjoint function (::Type{T})(t::NamedTuple) where {T<:StructArray}
result = T(t)
back::NamedTuple) = (fieldarrays(T(Δ)),)
back::NamedTuple) = (NamedTuple{propertynames(result)}),)
function back::AbstractArray{<:NamedTuple})
nt = (; (p => [dx[p] for dx in Δ] for p in propertynames(result))...)
return back(nt)
Expand All @@ -22,7 +22,7 @@ end
key::Symbol
result = getproperty(sa, key)
function back::AbstractArray)
nt = (; (p => zero(getproperty(sa, p)) for p in propertynames(sa))...)
nt = (; (k => zero(v) for (k,v) in pairs(fieldarrays(sa)))...)
return (Base.setindex(nt, Δ, key), nothing)
end
return result, back
Expand All @@ -33,3 +33,25 @@ end
back::NamedTuple) = (values(T(Δ)),)
return result, back
end

@adjoint Base.getindex(x::StructArray, i...) = x[i...], Δ -> ∇getindex(x, i, Δ)
@adjoint Base.view(x::StructArray, i...) = view(x, i...), Δ -> ∇getindex(x, i, Δ)
function ∇getindex(x::StructArray, i, Δ::NamedTuple)
dx = (; (k => ∇getindex(v, i, Δ[k]) for (k,v) in pairs(fieldarrays(x)))...)
di = map(_ -> nothing, i)
return (dx, map(_ -> nothing, i)...)
end
# based on
# https://github.com/FluxML/Zygote.jl/blob/64c02dccc698292c548c334a15ce2100a11403e2/src/lib/array.jl#L41
∇getindex(x::AbstractArray, i, Δ::Nothing) = nothing
function ∇getindex(x::AbstractArray, i, Δ)
if i isa NTuple{<:Any, Integer}
dx = Zygote._zero(x, typeof(Δ))
dx[i...] = Δ
else
dx = Zygote._zero(x, eltype(Δ))
dxv = view(dx, i...)
dxv .= Zygote.accum.(dxv, Zygote._droplike(Δ, dxv))
end
return dx
end
15 changes: 15 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,18 @@ end == ([1.0, 1.0], [2.0, 2.0])
S = StructArray{Complex}((im = Y, re = X))
sum(S).re + 2sum(S).im
end == ([1.0, 1.0], [2.0, 2.0])

@test gradient(randn(2), randn(2)) do X, Y
S = StructArray{Complex}((re = X, im = Y))
S[1].re
end == ([1.0, 0.0], nothing)

@test gradient(randn(2), randn(2)) do X, Y
S = StructArray{Complex}((re = X, im = Y))
S[1].re + S[1].im
end == ([1.0, 0.0], [1.0, 0.0])

@test gradient(randn(2), randn(2)) do X, Y
S = StructArray{Complex}((re = X, im = Y))
S[1].re + S[2].re
end == ([1.0, 1.0], nothing)

0 comments on commit d3f9371

Please sign in to comment.