Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Allow indexing with non-static ranges to produce static arrays #703

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
@inline index_size(::Size, ::Int) = Size()
@inline index_size(::Size, a::StaticArray) = Size(a)
@inline index_size(s::Size, ::Colon) = s
@inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,)
@inline index_size(::Size, a::AbstractRange{<:Integer}) = Size(length(a),)

@inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds)

Expand All @@ -92,9 +92,9 @@ linear_index_size(ind_sizes::Type{<:Size}...) = _linear_index_size((), ind_sizes
@inline _linear_index_size(t::Tuple, ::Type{Size{S}}, ind_sizes...) where {S} = _linear_index_size((t..., prod(S)), ind_sizes...)

_ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
_ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j])
_ind(i::Int, j::Int, ::Type{Colon}) = j
_ind(i::Int, j::Int, ::Type{<:SOneTo}) = j
_ind(i::Int, j::Int, ::Type{<:AbstractArray}) = :(inds[$i][$j])

################################
## Non-scalar linear indexing ##
Expand Down Expand Up @@ -215,7 +215,7 @@ end

# getindex

@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, AbstractRange, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
end

Expand Down
2 changes: 1 addition & 1 deletion test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ using StaticArrays, Test, LinearAlgebra
@test similar(v, SOneTo(3), SOneTo(4)) isa MMatrix{3,4,Int}
@test similar(v, 3, SOneTo(4)) isa Matrix

@test m[:, 1:2] isa Matrix
@test m[:, 1:2] isa SMatrix{2, 2, Int}
@test m[:, [true, false, false]] isa Matrix
@test m[:, SOneTo(2)] isa SMatrix{2, 2, Int}
@test m[:, :] isa SMatrix{2, 3, Int}
Expand Down
30 changes: 30 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,34 @@ using StaticArrays, Test
@test eltype(Bvv) == Int
@test Bvv[:] == [B[1,2,3,4], B[1,1,3,4]]
end

@testset "Indexing with constants" begin
function SVector_UnitRange()
x = SA[1, 2, 3]
x[2:end]
end
@test SVector_UnitRange() === SA[2, 3]
VERSION ≥ v"1.2" && @test_const_fold SVector_UnitRange()

function SVector_StepRange()
x = SA[1, 2, 3, 4]
x[1:2:end]
end
@test SVector_StepRange() === SA[1, 3]
VERSION ≥ v"1.2" && @test_const_fold SVector_StepRange()

function SMatrix_UnitRange_UnitRange()
x = SA[1 2 3; 4 5 6]
x[1:2, 2:end]
end
@test SMatrix_UnitRange_UnitRange() === SA[2 3; 5 6]
VERSION ≥ v"1.2" && @test_const_fold SMatrix_UnitRange_UnitRange()

function SMatrix_StepRange_StepRange()
x = SA[1 2 3; 4 5 6]
x[1:1:2, 1:2:end]
end
@test SMatrix_StepRange_StepRange() === SA[1 3; 4 6]
VERSION ≥ v"1.2" && @test_const_fold SMatrix_StepRange_StepRange()
end
end
39 changes: 39 additions & 0 deletions test/testutil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,45 @@ should_not_be_inlined(x) = _should_not_be_inlined(x)
end


"""
@test_const_fold f(args...)

Test that constant folding works with a function call `f(args...)`.
"""
macro test_const_fold(ex)
quote
ci, = $(esc(:($InteractiveUtils.@code_typed optimize = true $ex)))
@test $(esc(ex)) == constant_return(ci)
end
end

struct NonConstantValue end

function constant_return(ci)
if :rettype in fieldnames(typeof(ci))
ci.rettype isa Core.Compiler.Const && return ci.rettype.val
return NonConstantValue()
else
# for julia < 1.2
ex = ci.code[end]
Meta.isexpr(ex, :return) || return NonConstantValue()
val = ex.args[1]
return val isa QuoteNode ? val.value : val
end
end

@testset "@test_const_fold" begin
should_const_fold() = (1, 2, 3)
@test_const_fold should_const_fold()

x = Ref(1)
should_not_const_fold() = x[]
ts = @testset ErrorCounterTestSet "" begin
@test_const_fold should_not_const_fold()
end
@test ts.errorcount == 0 && ts.failcount == 1 && ts.passcount == 0
end

"""
@inferred_maybe_allow allow ex

Expand Down