Need adjoint for reinterpret SVector #606

Open
jenkspt opened this issue Apr 9, 2022 · 4 comments

jenkspt commented Apr 9, 2022

Minimum working example:

using StaticArrays, Zygote

f(x) = sum(sum(reinterpret(SVector{size(x, 1), eltype(x)}, x)))
Zygote.gradient(f, rand(3, 10))
ERROR: Need an adjoint for constructor Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}. Gradient is of type FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] (::Zygote.Jnew{Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}, Nothing, false})(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:323
 [3] (::Zygote.var"#1811#back#235"{Zygote.Jnew{Base.ReinterpretArray{SVector{3, Float64}, 2, Float64, Matrix{Float64}, false}, Nothing, false}})(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [4] Pullback
   @ ./reinterpretarray.jl:47 [inlined]
 [5] (::typeof(∂(reinterpret)))(Δ::FillArrays.Fill{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
 [6] Pullback
   @ ./REPL[95]:1 [inlined]
 [7] (::Zygote.var"#52#53"{typeof(∂(f))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
 [8] gradient(::Function, ::Matrix{Float64}, ::Vararg{Any})
   @ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
 [9] top-level scope
   @ REPL[98]:1

jenkspt commented Apr 9, 2022

I realize that a specific rule for SVector probably shouldn't be added to ChainRulesCore, but the goal is a more general solution for any reinterpreted composite type.

mcabbott (Member) commented

Would be nice. Instead of depending on StaticArrays, you could probably just specify abstract types, something like this:

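using ChainRulesCore  # provides rrule and NoTangent

# S must be declared before T, because T's bound refers to S.
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{T}, x::AbstractArray{S}) where {S,T<:AbstractArray{S}}
    # Pullback: no tangent for the function or the type; flatten dy back to eltype S.
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(S, dy))
    return reinterpret(T, x), unreinterpret
end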

and another signature for the other direction? Restricting T this way would prevent the rule from acting on things like reinterpret(Float32, [1.0, 2.0]), although it would also exclude things like reinterpreting to remove units.
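
For the other direction, a minimal sketch might look like the following. It assumes ChainRulesCore and StaticArrays are installed. The pullback name `rereinterpret` is made up for illustration, and the rule has not been checked against Zygote or the ChainRules test suite.

```julia
using ChainRulesCore   # rrule, NoTangent, unthunk
using StaticArrays     # only needed for the SVector usage example

# Sketch (an assumption, not an existing ChainRules rule): flatten an array of
# containers T (e.g. SVector{3,S}) into an array of their element type S.
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{S}, x::AbstractArray{T}) where {S,T<:AbstractArray{S}}
    # Pullback: regroup the flat cotangent into containers of type T.
    rereinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(T, unthunk(dy)))
    return reinterpret(S, x), rereinterpret
end

xs = [SVector(1.0, 2.0, 3.0), SVector(4.0, 5.0, 6.0)]
y, back = rrule(reinterpret, Float64, xs)   # y is a flat 6-element view of xs
_, _, dx = back(ones(length(y)))            # dx has the same shape as xs

# A bit-level reinterpret does not match the signature, so only the generic
# fallback applies and no rule is found.
rrule(reinterpret, Float32, [1.0, 2.0]) === nothing
```

Because the signature requires the container's element type to equal the target type S, a bit-level call such as reinterpret(Float32, [1.0, 2.0]) still finds no rule, which is the restriction discussed above.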

jenkspt commented Apr 10, 2022

Does it make sense to add your suggestion to ChainRules?

mcabbott (Member) commented

Yes, if it works... probably someone just has to tidy it up & figure out how to get tests working, etc.
