Skip to content

Sized AbstractArray #783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/MArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,12 @@ end
function promote_rule(::Type{<:MArray{S,T,N,L}}, ::Type{<:MArray{S,U,N,L}}) where {S,T,U,N,L}
MArray{S,promote_type(T,U),N,L}
end

function Base.view(
a::MArray{S},
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
) where {S}
new_size = new_out_size(S, indices...)
view_from_invoke = invoke(view, Tuple{AbstractArray, typeof(indices).parameters...}, a, indices...)
return SizedArray{new_size}(view_from_invoke)
end
216 changes: 171 additions & 45 deletions src/SizedArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
SizedArray{Tuple{dims...}}(array)

Wraps an `Array` with a static size, so to take advantage of the (faster)
Wraps an `AbstractArray` with a static size, so to take advantage of the (faster)
methods defined by the static array package. The size is checked once upon
construction to determine if the number of elements (`length`) match, but the
array may be reshaped.
Expand All @@ -11,37 +11,48 @@ The aliases `SizedVector{N}` and `SizedMatrix{N,M}` are provided as more
convenient names for one and two dimensional `SizedArray`s. For example, to
wrap a 2x3 array `a` in a `SizedArray`, use `SizedMatrix{2,3}(a)`.
"""
struct SizedArray{S <: Tuple, T, N, M} <: StaticArray{S, T, N}
data::Array{T, M}
struct SizedArray{S<:Tuple,T,N,M,TData<:AbstractArray{T,M}} <: StaticArray{S,T,N}
data::TData

function SizedArray{S, T, N, M}(a::Array) where {S, T, N, M}
if length(a) != tuple_prod(S)
function SizedArray{S,T,N,M,TData}(a::TData) where {S,T,N,M,TData<:AbstractArray{T,M}}
if size(a) != size_to_tuple(S) && size(a) != (tuple_prod(S),)
throw(DimensionMismatch("Dimensions $(size(a)) don't match static size $S"))
end
if size(a) != size_to_tuple(S)
Base.depwarn("Construction of `SizedArray` with an `Array` of a different
size is deprecated. If you need this functionality report it at
https://github.com/JuliaArrays/StaticArrays.jl/pull/666 .
Calling `sa = reshape(a::Array, s::Size)` will actually reshape
array `a` in the future and converting `sa` back to `Array` will
return an `Array` of shape `s`.", :SizedArray)
end
new{S,T,N,M}(a)
return new{S,T,N,M,TData}(a)
end

function SizedArray{S, T, N, M}(::UndefInitializer) where {S, T, N, M}
new{S, T, N, M}(Array{T, M}(undef, size_to_tuple(S)...))
function SizedArray{S,T,N,1,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,1}}
return new{S,T,N,1,TData}(TData(undef, tuple_prod(S)))
end
function SizedArray{S,T,N,N,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,N}}
return new{S,T,N,N,TData}(TData(undef, size_to_tuple(S)...))
end
end

@inline SizedArray{S,T,N}(a::Array{T,M}) where {S,T,N,M} = SizedArray{S,T,N,M}(a)
@inline SizedArray{S,T}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)
@inline SizedArray{S}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)

@inline SizedArray{S,T,N}(::UndefInitializer) where {S,T,N} = SizedArray{S,T,N,N}(undef)
@inline SizedArray{S,T}(::UndefInitializer) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(undef)

@generated function SizedArray{S,T,N,M}(x::NTuple{L,Any}) where {S,T,N,M,L}
@inline function SizedArray{S,T,N}(
a::TData,
) where {S,T,N,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,N,M,TData}(a)
end
@inline function SizedArray{S,T}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,tuple_length(S),M,TData}(a)
end
@inline function SizedArray{S}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,tuple_length(S),M,TData}(a)
end
function SizedArray{S,T,N,N}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,N,Array{T,N}}(undef)
end
function SizedArray{S,T,N,1}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,1,Vector{T}}(undef)
end
@inline function SizedArray{S,T,N}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,N}(undef)
end
@inline function SizedArray{S,T}(::UndefInitializer) where {S,T}
return SizedArray{S,T,tuple_length(S)}(undef)
end
@generated function (::Type{SizedArray{S,T,N,M,TData}})(x::NTuple{L,Any}) where {S,T,N,M,TData<:AbstractArray{T,M},L}
if L != tuple_prod(S)
error("Dimension mismatch")
end
Expand All @@ -53,43 +64,158 @@ end
return a
end
end

@inline SizedArray{S,T,N}(x::Tuple) where {S,T,N} = SizedArray{S,T,N,N}(x)
@inline SizedArray{S,T}(x::Tuple) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
@inline SizedArray{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
@inline function SizedArray{S,T,N,M}(x::Tuple) where {S,T,N,M}
return SizedArray{S,T,N,M,Array{T,M}}(x)
end
@inline function SizedArray{S,T,N}(x::Tuple) where {S,T,N}
return SizedArray{S,T,N,N,Array{T,N}}(x)
end
@inline function SizedArray{S,T}(x::Tuple) where {S,T}
return SizedArray{S,T,tuple_length(S)}(x)
end
@inline function SizedArray{S}(x::NTuple{L,T}) where {S,T,L}
return SizedArray{S,T}(x)
end

# Overide some problematic default behaviour
@inline convert(::Type{SA}, sa::SizedArray) where {SA<:SizedArray} = SA(sa.data)
@inline convert(::Type{SA}, sa::SA) where {SA<:SizedArray} = sa

# Back to Array (unfortunately need both convert and construct to overide other methods)
@inline Array(sa::SizedArray) = Array(sa.data)
@inline Array{T}(sa::SizedArray{S,T}) where {T,S} = Array{T}(sa.data)
@inline Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N} = Array{T,N}(sa.data)
@inline function Base.Array(sa::SizedArray{S}) where {S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function Base.Array{T}(sa::SizedArray{S,T}) where {T,S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function Base.Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N}
return Array(reshape(sa.data, size_to_tuple(S)))
end

@inline convert(::Type{Array}, sa::SizedArray) = sa.data
@inline convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S} = sa.data
@inline convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N}) where {T,S,N} = sa.data
@inline function convert(::Type{Array}, sa::SizedArray{S}) where {S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
return sa.data
end
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
return sa.data
end
@inline function convert(
::Type{Array{T,N}},
sa::SizedArray{S,T,N},
) where {T,S,N}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N,N,Array{T,N}}) where {S,T,N}
return sa.data
end

@propagate_inbounds getindex(a::SizedArray, i::Int) = getindex(a.data, i)
@propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i)

SizedVector{S,T,M} = SizedArray{Tuple{S},T,1,M}
@inline SizedVector{S}(a::Array{T,M}) where {S,T,M} = SizedArray{Tuple{S},T,1,M}(a)
@inline SizedVector{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{Tuple{S},T,1,1}(x)
Base.parent(sa::SizedArray) = sa.data

SizedMatrix{S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}
@inline SizedMatrix{S1,S2}(a::Array{T,M}) where {S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}(a)
@inline SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L} = SizedArray{Tuple{S1,S2},T,2,2}(x)
const SizedVector{S,T} = SizedArray{Tuple{S},T,1,1}

@inline function SizedVector{S}(a::TData) where {S,T,TData<:AbstractVector{T}}
return SizedArray{Tuple{S},T,1,1,TData}(a)
end
@inline function SizedVector(x::NTuple{S,T}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
@inline function SizedVector{S}(x::NTuple{S,T}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
@inline function SizedVector{S,T}(x::NTuple{S}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
# disambiguation
@inline function SizedVector{S}(a::StaticVector{S,T}) where {S,T}
return SizedVector{S,T}(a.data)
end

const SizedMatrix{S1,S2,T} = SizedArray{Tuple{S1,S2},T,2}

@inline function SizedMatrix{S1,S2}(
a::TData,
) where {S1,S2,T,M,TData<:AbstractArray{T,M}}
return SizedArray{Tuple{S1,S2},T,2,M,TData}(a)
end
@inline function SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L}
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
end
@inline function SizedMatrix{S1,S2,T}(x::NTuple{L}) where {S1,S2,T,L}
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
end
# disambiguation
@inline function SizedMatrix{S1,S2}(a::StaticMatrix{S1,S2,T}) where {S1,S2,T}
return SizedMatrix{S1,S2,T}(a.data)
end

Base.dataids(sa::SizedArray) = Base.dataids(sa.data)

function (::Size{S})(a::Array) where {S}
Base.depwarn("`Size{S}(a::Array)` is deprecated, use `SizedVector{N}(a)`, `SizedMatrix{N,M}(a)` or `SizedArray{Tuple{S}}(a)` instead", :Size)
SizedArray{Tuple{S...}}(a)
function promote_rule(
::Type{SizedArray{S,T,N,M,TDataA}},
::Type{SizedArray{S,U,N,M,TDataB}},
) where {S,T,U,N,M,TDataA,TDataB}
TU = promote_type(T, U)
return SizedArray{S, TU, N, M, promote_type(TDataA, TDataB)}
end

function promote_rule(
::Type{SizedArray{S,T,N,M}},
::Type{SizedArray{S,U,N,M}},
) where {S,T,U,N,M,}
TU = promote_type(T, U)
return SizedArray{S, TU, N, M}
end

function promote_rule(
::Type{SizedArray{S,T,N}},
::Type{SizedArray{S,U,N}},
) where {S,T,U,N}
TU = promote_type(T, U)
return SizedArray{S, TU, N}
end


### Code that makes views of statically sized arrays also statically sized (where possible)

@generated function new_out_size(::Type{Size}, inds...) where Size
os = []
map(Size.parameters, inds) do s, i
if i <: Integer
# dimension is fixed
elseif i <: StaticVector
push!(os, i.parameters[1].parameters[1])
elseif i == Colon || i <: Base.Slice
push!(os, s)
elseif i <: SOneTo
push!(os, i.parameters[1])
else
error("Unknown index type: $i")
end
end
return Tuple{os...}
end

@generated function new_out_size(::Type{Size}, ::Colon) where Size
prod_size = tuple_prod(Size)
return Tuple{prod_size}
end

function Base.view(
a::SizedArray{S},
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
) where {S}
new_size = new_out_size(S, indices...)
return SizedArray{new_size}(view(a.data, indices...))
end

function promote_rule(::Type{<:SizedArray{S,T,N,M}}, ::Type{<:SizedArray{S,U,N,M}}) where {S,T,U,N,M}
SizedArray{S,promote_type(T,U),N,M}
function Base.vec(a::SizedArray{S}) where {S}
return SizedVector{tuple_prod(S)}(vec(a.data))
end
8 changes: 4 additions & 4 deletions src/matrix_multiply_add.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ Base.transpose(::TSize{S,T}) where {S,T} = TSize{reverse(S),!T}()

# Get the parent of transposed arrays, or the array itself if it has no parent
# QUESTION: maybe call this something else?
Base.parent(A::Union{<:Transpose{<:Any,<:StaticArray}, <:Adjoint{<:Any,<:StaticArray}}) = A.parent
Base.parent(A::StaticArray) = A
mul_parent(A) = parent(A)
mul_parent(A::StaticArray) = A

# 5-argument matrix multiplication
# To avoid allocations, strip away Transpose type and store tranpose info in Size
@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike, B::StaticVecOrMatLike,
α::Real, β::Real) = _mul!(TSize(dest), parent(dest), TSize(A), TSize(B), parent(A), parent(B),
α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B),
AlphaBeta(α,β))

@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike{T},
B::StaticVecOrMatLike{T}) where T =
_mul!(TSize(dest), parent(dest), TSize(A), TSize(B), parent(A), parent(B), NoMulAdd{T}())
_mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B), NoMulAdd{T}())


"Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling."
Expand Down
Loading