Is the sensitivity of a struct always a NamedTuple? #680

Open
cossio opened this issue Jun 8, 2020 · 0 comments

cossio (Contributor) commented Jun 8, 2020

Shouldn't the sensitivity of a struct always be a NamedTuple? Why is Complex treated differently? The adjoint of the Complex constructor is currently:

```julia
@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))
```

That line calls `real(c̄)` and `imag(c̄)`, which assumes these methods work for `c̄`. They don't if `c̄` is a NamedTuple.
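A minimal, Zygote-free illustration of the failure in plain Base Julia (the variable names are illustrative only): `real` has no method for a NamedTuple, while field access works on both forms, because `Complex` itself stores its parts in fields named `re` and `im`.

```julia
# Two possible forms of the cotangent for z = complex(re, im):
dz_complex = 3.0 + 4.0im           # a Complex number
dz_named   = (re = 3.0, im = 4.0)  # a NamedTuple with the same field names

# real/imag work on the Complex form...
@show real(dz_complex) imag(dz_complex)   # 3.0 and 4.0

# ...but have no method for a NamedTuple, which is what breaks the current pullback.
try
    real(dz_named)
catch err
    @show err isa MethodError             # true
end

# Field access works on both forms, since Complex has fields `re` and `im`.
@show (dz_complex.re, dz_complex.im) == (dz_named.re, dz_named.im)   # true
```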

So how about replacing it with:

```julia
@adjoint function (T::Type{<:Complex})(re, im)
    back(c::Complex) = (nothing, real(c), imag(c))
    back(c::NamedTuple) = (nothing, c.re, c.im)
    T(re, im), back
end
```
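The dispatch in the proposed pullback can be checked without Zygote. This is a sketch only: `complex_back` and `complex_back_fields` are hypothetical names standing in for the inner `back` closure, and the single-method variant assumes it is acceptable to rely on Complex's internal field names.

```julia
# Zygote-free stand-in for the proposed pullback (hypothetical name).
# Dispatch picks the method matching the form of the cotangent.
complex_back(c::Complex)    = (nothing, real(c), imag(c))
complex_back(c::NamedTuple) = (nothing, c.re, c.im)

complex_back(1.0 + 2.0im)           # (nothing, 1.0, 2.0)
complex_back((re = 1.0, im = 2.0))  # (nothing, 1.0, 2.0)

# Possible single-method alternative: read the fields directly. This relies on
# Complex storing its parts in fields `re` and `im`, matching the NamedTuple form.
complex_back_fields(c) = (nothing, c.re, c.im)
complex_back_fields(1.0 + 2.0im) == complex_back((re = 1.0, im = 2.0))  # true
```

The two-method version is explicit about which cotangent forms it accepts; the field-access version is shorter but couples the pullback to an implementation detail of `Complex`.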

The current special-casing of Complex leads to downstream inconsistencies. Here is an example where it gave me trouble.

I want to define the adjoint of the `sa = StructArray(a::Array{T})` constructor. For that I need a function `Δa = back(Δsa)` that takes the sensitivity with respect to `sa` and returns the sensitivity with respect to `a`. `Δsa` is a NamedTuple whose fields are arrays, with names matching the fields of the struct `T`. Assuming that the sensitivity with respect to a struct is always represented as a NamedTuple, I can write the desired adjoint as follows:

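```julia
using Zygote: @adjoint
using StructArrays: StructArray

@adjoint function (::Type{SA})(a::A) where {T,SA<:StructArray,A<:AbstractArray{T}}
    sa = SA(a)
    function back(Δsa)
        # One NamedTuple of per-field sensitivities for each element of `a`;
        # iterating over CartesianIndices keeps Δa the same shape as `a`.
        Δa = [(; (p => Δsa[p][i] for p in propertynames(Δsa))...) for i in CartesianIndices(a)]
        return (Δa,)
    end
    return sa, back
end
```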

Here `back(Δsa)` returns an array `Δa` of NamedTuples containing the sensitivity with respect to each entry of `a`.
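The reshaping done inside `back` can be sketched on its own, without Zygote or StructArrays. The helper `soa_to_aos` and the struct `Point` are hypothetical, introduced only to show how a NamedTuple of arrays becomes an array of NamedTuples; the last lines show that a Complex array produces `(re = ..., im = ...)` NamedTuples, which is the form some Complex adjoints do not accept.

```julia
# Hypothetical helper mirroring the body of `back`: turn a NamedTuple of arrays
# (one array per struct field) into an array of NamedTuples (one per element),
# with the same shape as `a`.
function soa_to_aos(Δsa::NamedTuple, a::AbstractArray)
    names = propertynames(Δsa)
    return [NamedTuple{names}(map(p -> getproperty(Δsa, p)[i], names))
            for i in CartesianIndices(a)]
end

# Example with a hypothetical two-field struct.
struct Point
    x::Float64
    y::Float64
end

a   = [Point(1, 2), Point(3, 4)]
Δsa = (x = [10.0, 30.0], y = [20.0, 40.0])  # sensitivity w.r.t. the StructArray
Δa  = soa_to_aos(Δsa, a)                     # [(x = 10.0, y = 20.0), (x = 30.0, y = 40.0)]

# For a Complex eltype the same code yields NamedTuples with fields re and im.
Δc = soa_to_aos((re = [1.0], im = [2.0]), [1.0 + 2.0im])  # [(re = 1.0, im = 2.0)]
```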

But this does not work for Complex (though it works for every other struct I have tried). The reason is that some adjoints expect the sensitivity with respect to a Complex to be another Complex rather than a NamedTuple.

The only solution I found is to write a special @adjoint for the StructArray constructor when the eltype is Complex (see https://github.com/cossio/ZygoteStructArrays.jl).
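One way such a special case could reconcile the two representations is to convert each per-element NamedTuple back into a Complex number. This is a hypothetical sketch, not the code in ZygoteStructArrays.jl; the name `to_complex_cotangent` is invented for illustration.

```julia
# Hypothetical normalization step (not taken from ZygoteStructArrays.jl):
# map a structural cotangent for a Complex back to a Complex number, so that
# adjoints expecting Complex cotangents keep working.
to_complex_cotangent(Δ::Complex)    = Δ
to_complex_cotangent(Δ::NamedTuple) = complex(Δ.re, Δ.im)

# Applied elementwise to the per-element cotangents of a Complex array:
Δa_named   = [(re = 1.0, im = -1.0), (re = 0.5, im = 2.0)]
Δa_complex = map(to_complex_cotangent, Δa_named)  # [1.0 - 1.0im, 0.5 + 2.0im]
```

Needing a conversion like this for one eltype only is the kind of inconsistency a general rule would avoid.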

I think we need a general rule; otherwise we get inconsistencies like this one.
