I realize that a rule specific to SVector probably shouldn't be added to ChainRulesCore; the goal is a more general solution for any reinterpreted composite type.
Would be nice. Instead of depending on StaticArrays, you could probably just specify abstract types, something like this:
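```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{T}, x::AbstractArray{S}) where {S,T<:AbstractArray{S}}
    # Unthunk in case the cotangent arrives lazily, then reinterpret it back to eltype S.
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(S, unthunk(dy)))
    return reinterpret(T, x), unreinterpret
end
```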
And another signature for the other direction? Restricting the signature this way would keep the rule from acting on things like `reinterpret(Float32, [1.0, 2.0])`, although it would also exclude cases such as reinterpreting to remove units.
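A minimal sketch of both directions, assuming ChainRulesCore is installed. To avoid a StaticArrays dependency it uses a hypothetical `Vec3` type (an isbits struct subtyping `AbstractVector{Float64}`) as a stand-in for `SVector{3,Float64}`. The second method handles composites back to scalars, and the last call shows that a plain scalar reinterpret such as `Float64` to `Float32` matches neither signature, so ChainRulesCore's generic fallback returns `nothing`.

```julia
using ChainRulesCore  # assumed to be installed

# Hypothetical stand-in for SVector{3,Float64}: an isbits struct that is also an
# AbstractVector{Float64}, so it satisfies `T <: AbstractArray{S}` with S = Float64.
struct Vec3 <: AbstractVector{Float64}
    x::Float64
    y::Float64
    z::Float64
end
Base.size(::Vec3) = (3,)
Base.getindex(v::Vec3, i::Int) = getfield(v, i)

# Flat array of scalars S -> array of composites T (e.g. Vector{Float64} -> Vec3s).
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{T}, x::AbstractArray{S}) where {S,T<:AbstractArray{S}}
    # The pullback reinterprets the cotangent back to the scalar eltype S.
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(S, unthunk(dy)))
    return reinterpret(T, x), unreinterpret
end

# The other direction: array of composites T -> flat array of scalars S.
function ChainRulesCore.rrule(::typeof(reinterpret), ::Type{S}, x::AbstractArray{T}) where {S,T<:AbstractArray{S}}
    # The pullback reinterprets the cotangent back to the composite eltype T.
    unreinterpret(dy) = (NoTangent(), NoTangent(), reinterpret(T, unthunk(dy)))
    return reinterpret(S, x), unreinterpret
end

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y, pb = rrule(reinterpret, Vec3, x)      # y holds two Vec3 values
_, _, dx = pb([Vec3(1.0, 0.0, 0.0), Vec3(0.0, 2.0, 0.0)])

z, pb2 = rrule(reinterpret, Float64, y)  # back to six Float64 values
_, _, dv = pb2(ones(6))

# Float32 is not an AbstractArray{Float64} (nor vice versa), so neither method
# applies and ChainRulesCore's fallback rrule returns `nothing`.
no_rule = rrule(reinterpret, Float32, [1.0, 2.0])
```

Because `reinterpret` between these two layouts is a linear bijection that keeps every scalar component in place, reinterpreting the cotangent back is its adjoint. That reasoning only holds when every field of the composite type is a scalar of type `S`.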
Minimum working example: