diff --git a/src/adjoints.jl b/src/adjoints.jl index 6e25cec..b9a0ce8 100644 --- a/src/adjoints.jl +++ b/src/adjoints.jl @@ -17,9 +17,22 @@ end return sa, back end -@adjoint function (::Type{SA})(a::AbstractArray{T}) where {T,SA<:StructArray} +@adjoint function (::Type{SA})(a::A) where {T,SA<:StructArray,A<:AbstractArray{T}} sa = SA(a) - back(Δ::NamedTuple) = ([(; (p => Δ[p][i] for p in propertynames(Δ))...) for i in eachindex(a)],) + function back(Δsa) + Δa = [(; (p => Δsa[p][i] for p in propertynames(Δsa))...) for i in eachindex(a)] + return (Δa,) + end + return sa, back +end + +# Must special-case for Complex (#1) +@adjoint function (::Type{SA})(a::A) where {T<:Complex,SA<:StructArray,A<:AbstractArray{T}} + sa = SA(a) + function back(Δsa) # dsa -> da + Δa = [Complex(Δsa.re[i], Δsa.im[i]) for i in eachindex(a)] + (Δa,) + end return sa, back end diff --git a/test/adjoints.jl b/test/adjoints.jl index 24e5a21..34d91a6 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -54,7 +54,13 @@ end == ([1.0, 1.0], nothing) sum(S.x) + sum(S.y) end == ([1.0, 1.0], [1.0, 1.0]) +#1 +@test gradient(randn(2), randn(2)) do X, Y + S = StructArray(Complex.(X, Y)) + sum(S.re) + sum(S.im) +end == ([1.0, 1.0], [1.0, 1.0]) + # @test gradient(randn(2), randn(2)) do X, Y # S = StructArray{Complex}((re = X, im = Y)) # sum(abs.(S)) -# end \ No newline at end of file +# end