StaticArrays #479

@mcabbott

I think we would generically like StaticArray arguments to have StaticArray tangents. The current behaviour depends on what path you hit:

```julia
julia> using StaticArrays, ChainRulesCore

julia> p = ProjectTo(SA[1,2,3])  # has SOneTo
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (SOneTo(3),))

julia> p([4,5,6])  # doesn't reshape
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> p(ones(3,1))  # does reshape
3-element SizedVector{3, Float64, Vector{Float64}} with indices SOneTo(3):
 1.0
 1.0
 1.0
```

If we change this line to test `===` instead of `==`, then the first case would behave like the second:

```julia
dy = if axes(dx) == project.axes
```
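A package-free sketch of what that change does to the branch. The helpers `takes_reshape_eq` and `takes_reshape_egal` are hypothetical and only report which path would be taken. A plain `1:3` is assumed to stand in for `SOneTo(3)`: like it, it has the same values as `Base.OneTo(3)` but a different type.

```julia
# Hypothetical stand-ins for the axis check in ProjectTo's array method.
# They only report whether the reshape branch would be taken.
takes_reshape_eq(target_axes, dx::AbstractArray)   = !(axes(dx) == target_axes)
takes_reshape_egal(target_axes, dx::AbstractArray) = !(axes(dx) === target_axes)

dx = [4.0, 5.0, 6.0]   # axes(dx) is (Base.OneTo(3),)
target = (1:3,)        # assumed stand-in for (SOneTo(3),): same values, different type

takes_reshape_eq(target, dx)      # false: `==` compares index values, so dx is returned as-is
takes_reshape_egal(target, dx)    # true: `===` also distinguishes the types, so dx gets reshaped
takes_reshape_egal(axes(dx), dx)  # false: identical axes still skip the reshape
```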

Would this have any surprising downsides? It would also improve the type stability of cases like this:

```julia
julia> using OffsetArrays

julia> p2 = ProjectTo(zeros(1:3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (OffsetArrays.IdOffsetRange(values=1:3, indices=1:3),))

julia> @code_warntype p2(ones(3))
```
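The type-stability point can be checked the same way. In this sketch, `pick_eq` and `pick_egal` are hypothetical functions whose two branches return different types. They stand in for "return `dx` unchanged" versus "reshape it". `UnitRange` is assumed to play the role of an axis type other than `Base.OneTo`, as `IdOffsetRange` does above. Because `===` between values of different concrete types is always `false`, inference can drop the first branch. With `==` it cannot, since the answer depends on the runtime values.

```julia
# Hypothetical functions whose two branches deliberately return different types.
pick_eq(target_axes, dx)   = axes(dx) == target_axes  ? dx : nothing
pick_egal(target_axes, dx) = axes(dx) === target_axes ? dx : nothing

# UnitRange stands in for an axis type other than Base.OneTo (like IdOffsetRange above).
argtypes = (Tuple{UnitRange{Int}}, Vector{Float64})

# Base.return_types reports what inference concludes, much like @code_warntype.
rt_eq   = only(Base.return_types(pick_eq, argtypes))    # Union{Nothing, Vector{Float64}}
rt_egal = only(Base.return_types(pick_egal, argtypes))  # Nothing: the `dx` branch is provably dead
```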

Going the other way, if the argument is an ordinary Vector and the tangent is an SVector, then reshape won't do anything:

```julia
julia> reshape(SA[1,2,3], axes([4,5,6]))
3-element SVector{3, Int64} with indices SOneTo(3):
 1
 2
 3
```

This is, I think, the case in FluxML/Zygote.jl#1093. A simple example of a rule accidentally producing an `SVector` is:

```julia
julia> using Zygote, LinearAlgebra

julia> gradient(x -> dot(SA[1,2,3], x), rand(3))[1]  # dy = reshape(x .* ΔΩ, axes(y))
3-element SVector{3, Float64} with indices SOneTo(3):
 1.0
 2.0
 3.0
```
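The same leak can be reproduced without StaticArrays, which might be handy for testing a fix. In this sketch, `dy_from_dot` is a hypothetical one-line version of the rule's `dy = reshape(x .* ΔΩ, axes(y))`. A range is assumed to stand in for `SA[1,2,3]`. Broadcasting `*` over a range with a scalar gives back a range, and `reshape` to the axes of a `Vector` passes it through unchanged. As a result, the tangent's type follows `x` rather than `y`.

```julia
# Hypothetical one-line version of the tangent for `y` in `dot(x, y)`.
dy_from_dot(x, y, ΔΩ) = reshape(x .* ΔΩ, axes(y))

x = 1.0:3.0        # non-Array vector, assumed stand-in for SA[1,2,3]
y = rand(3)        # ordinary Vector argument
dy = dy_from_dot(x, y, 1.0)

dy isa Vector      # false: the tangent's type comes from `x`, not from `y`
collect(dy)        # [1.0, 2.0, 3.0]
```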

Is it a good idea, as a general rule, to force those back to `Array`? And if so, is there an easy way to implement this without ChainRulesCore depending on StaticArrays, or the reverse?
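One possible shape for an answer, sketched under assumptions. If the projector records that the primal was a plain `Array`, it could `collect` any other `AbstractArray` tangent and then reshape to the primal's axes. This uses only Base, so it needs no dependency in either direction. `project_to_array` is hypothetical, not ChainRulesCore's API. It ignores element projection and assumes `collect` returns an `Array` for the tangent types involved, as it does for Base ranges.

```julia
# Hypothetical sketch: when the primal is a plain Array, make the tangent one too.
# Only Base functions are used, so neither package would need to depend on the other.
function project_to_array(x::Array, dx::AbstractArray)
    dx_array = dx isa Array ? dx : collect(dx)  # e.g. a range (standing in for an SVector) becomes an Array
    return reshape(dx_array, axes(x))          # e.g. a 3×1 Matrix becomes a length-3 Vector
end

y = rand(3)
project_to_array(y, (1.0:3.0) .* 1.0)  # Vector{Float64} with values 1.0, 2.0, 3.0
project_to_array(y, ones(3, 1))        # Vector{Float64} of ones
```

The design choice here is to key on the primal's type rather than the tangent's, so a non-`Array` tangent never leaks out for an `Array` argument.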

Labels: ProjectTo (related to the projection functionality)