Closed
Labels: ProjectTo (related to the projection functionality)
Description
I think we would generically like StaticArray arguments to have StaticArray tangents. The current behaviour depends on what path you hit:
julia> using StaticArrays, ChainRulesCore
julia> p = ProjectTo(SA[1,2,3]) # has SOneTo
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (SOneTo(3),))
julia> p([4,5,6]) # doesn't reshape
3-element Vector{Float64}:
4.0
5.0
6.0
julia> p(ones(3,1)) # does reshape
3-element SizedVector{3, Float64, Vector{Float64}} with indices SOneTo(3):
1.0
1.0
1.0
If we change this line to test ===, then the first would be like the second:
ChainRulesCore.jl/src/projection.jl, line 214 in 0e560c6:
dy = if axes(dx) == project.axes
Would this have any surprising downsides? It will also improve type-stability of things like this:
julia> p2 = ProjectTo(zeros(1:3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (OffsetArrays.IdOffsetRange(values=1:3, indices=1:3),))
julia> @code_warntype p2(ones(3))
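For illustration, here is a minimal sketch of how the two comparisons differ on the axes from the first example. It assumes only StaticArrays is loaded, and the variable names are made up for this snippet:

```julia
using StaticArrays

# Axes of a static vector and of an ordinary Vector of the same length.
ax_static = axes(SA[1, 2, 3])   # a tuple holding SOneTo(3)
ax_plain  = axes([4, 5, 6])     # a tuple holding Base.OneTo(3)

# == compares the ranges by value, so these count as equal and no reshape happens.
same_values = ax_static == ax_plain

# === also requires identical types, so SOneTo vs Base.OneTo is a mismatch,
# which is what would send p([4,5,6]) down the reshape path.
same_identity = ax_static === ax_plain

println((same_values, same_identity))
```

With ===, any tangent whose axes differ in type from the stored ones would be reshaped, even when the index values agree.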
Going the other way, if the argument is an ordinary Vector and the tangent is an SVector, then reshape won't do anything:
julia> reshape(SA[1,2,3], axes([4,5,6]))
3-element SVector{3, Int64} with indices SOneTo(3):
1
2
3
This is the case of FluxML/Zygote.jl#1093, I think. A simple example of where the rule accidentally makes an SVector is:
julia> gradient(x -> dot(SA[1,2,3], x), rand(3))[1] # dy = reshape(x .* ΔΩ, axes(y))
3-element SVector{3, Float64} with indices SOneTo(3):
1.0
2.0
3.0
Is it a good idea as a general rule to force those to be converted back to Array? And if so, is there an easy way to implement this without ChainRulesCore depending on StaticArrays, or the reverse?