Skip to content

Commit

Permalink
special-case Complex
Browse files Browse the repository at this point in the history
fix #1
  • Loading branch information
Cossio committed Jun 9, 2020
1 parent ed08e72 commit 6601382
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
17 changes: 15 additions & 2 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,22 @@ end
return sa, back
end

@adjoint function (::Type{SA})(a::AbstractArray{T}) where {T,SA<:StructArray}
@adjoint function (::Type{SA})(a::A) where {T,SA<:StructArray,A<:AbstractArray{T}}
sa = SA(a)
back::NamedTuple) = ([(; (p => Δ[p][i] for p in propertynames(Δ))...) for i in eachindex(a)],)
function back(Δsa)
Δa = [(; (p => Δsa[p][i] for p in propertynames(Δsa))...) for i in eachindex(a)]
return (Δa,)
end
return sa, back
end

# Must special-case for Complex (#1)
@adjoint function (::Type{SA})(a::A) where {T<:Complex,SA<:StructArray,A<:AbstractArray{T}}
sa = SA(a)
function back(Δsa) # dsa -> da
Δa = [Complex(Δsa.re[i], Δsa.im[i]) for i in eachindex(a)]
(Δa,)
end
return sa, back
end

Expand Down
8 changes: 7 additions & 1 deletion test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ end == ([1.0, 1.0], nothing)
sum(S.x) + sum(S.y)
end == ([1.0, 1.0], [1.0, 1.0])

#1
@test gradient(randn(2), randn(2)) do X, Y
S = StructArray(Complex.(X, Y))
sum(S.re) + sum(S.im)
end == ([1.0, 1.0], [1.0, 1.0])

# @test gradient(randn(2), randn(2)) do X, Y
# S = StructArray{Complex}((re = X, im = Y))
# sum(abs.(S))
# end
# end

0 comments on commit 6601382

Please sign in to comment.