Shouldn't the sensitivity of a struct always be a `NamedTuple`? Why is `Complex` treated differently? See Zygote.jl/src/lib/number.jl, line 20 in 64c02dc.

That line calls `real(c̄)` and `imag(c̄)`, which assumes that these methods work for `c̄`. They don't if `c̄` is a `NamedTuple`.

So how about replacing that with

```julia
@adjoint function (T::Type{<:Complex})(re, im)
    # Accept the sensitivity either as a Complex or as a NamedTuple
    back(c::Complex) = (nothing, real(c), imag(c))
    back(c::NamedTuple) = (nothing, c.re, c.im)
    T(re, im), back
end
```

The current special treatment of `Complex` leads to downstream inconsistencies. Here is the example where this gave me trouble.

I want to define the adjoint of the `sa = StructArray(a::Array{T})` constructor. So I need a `Δa = back(Δsa)` function that takes the sensitivities with respect to `sa` and outputs the sensitivities with respect to `a`. `Δsa` is a `NamedTuple` whose fields are arrays with names matching the fields of the struct `T`. Under the assumption that the sensitivity with respect to a struct is represented as a `NamedTuple`, I can write the desired adjoint as follows:

```julia
@adjoint function (::Type{SA})(a::A) where {T,SA<:StructArray,A<:AbstractArray{T}}
    sa = SA(a)
    function back(Δsa)
        # One NamedTuple per element of a, built from the per-field arrays in Δsa
        Δa = [(; (p => Δsa[p][i] for p in propertynames(Δsa))...) for i in eachindex(a)]
        return (Δa,)
    end
    return sa, back
end
```

Here `back(Δsa)` returns an array `Δa` of `NamedTuple`s containing the sensitivities of each entry in `a`.

But this does not work for `Complex` (though it works for every other struct I have tried). The reason is that some functions expect the sensitivity with respect to a `Complex` to be another `Complex` instead of a `NamedTuple`.

The only solution I found is to write a special `@adjoint` for the `StructArray` constructor when the eltype is `Complex` (see https://github.com/cossio/ZygoteStructArrays.jl).

I think that we need a general rule, or we get inconsistencies.