diff --git a/base/abstractarray.jl b/base/abstractarray.jl index af29aee6a74b6..239e75df52510 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1716,13 +1716,7 @@ end _cs(d, a, b) = (a == b ? a : throw(DimensionMismatch( "mismatch in dimension $d (expected $a got $b)"))) -function dims2cat(::Val{dims}) where dims - if any(≤(0), dims) - throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) - end - ntuple(in(dims), maximum(dims)) -end - +dims2cat(::Val{dims}) where dims = dims2cat(dims) function dims2cat(dims) if any(≤(0), dims) throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) @@ -1730,9 +1724,8 @@ function dims2cat(dims) ntuple(in(dims), maximum(dims)) end -_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) +_cat(dims, X...) = _cat_t(dims, promote_eltypeof(X...), X...) -@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) @inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) shape = cat_size_shape(catdims, X...) @@ -1742,6 +1735,9 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) end return __cat(A, shape, catdims, X...) end +# this version of `cat_t` is not very kind for inference and so its usage should be avoided, +# nevertheless it is here just for compat after https://github.com/JuliaLang/julia/pull/45028 +@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) # Why isn't this called `__cat!`? __cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) @@ -1880,8 +1876,8 @@ julia> reduce(hcat, vs) """ hcat(X...) = cat(X...; dims=Val(2)) -typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1)) -typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2)) +typed_vcat(::Type{T}, X...) where T = _cat_t(Val(1), T, X...) +typed_hcat(::Type{T}, X...) where T = _cat_t(Val(2), T, X...) """ cat(A...; dims) @@ -1917,7 +1913,8 @@ julia> cat(true, trues(2,2), trues(4)', dims=(1,2)) ``` """ @inline cat(A...; dims) = _cat(dims, A...) -_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims) +# `@constprop :aggressive` allows `catdims` to be propagated as constant improving return type inference +@constprop :aggressive _cat(catdims, A::AbstractArray{T}...) where {T} = _cat_t(catdims, T, A...) # The specializations for 1 and 2 inputs are important # especially when running with --inline=no, see #11158 @@ -1928,12 +1925,12 @@ hcat(A::AbstractArray) = cat(A; dims=Val(2)) hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2)) hcat(A::AbstractArray...) = cat(A...; dims=Val(2)) -typed_vcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(1)) -typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(1)) -typed_vcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(1)) -typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2)) -typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2)) -typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2)) +typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A) +typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B) +typed_vcat(T::Type, A::AbstractArray...) = _cat_t(Val(1), T, A...) +typed_hcat(T::Type, A::AbstractArray) = _cat_t(Val(2), T, A) +typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(2), T, A, B) +typed_hcat(T::Type, A::AbstractArray...) = _cat_t(Val(2), T, A...) # 2d horizontal and vertical concatenation diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 8d4292c6045ed..098df785e557a 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -414,14 +414,14 @@ const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpo promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix -Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(promote_eltype(xs...), xs...; dims=dims) +Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...) vcat(A::Vector...) = Base.typed_vcat(promote_eltype(A...), A...) vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...) hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...) hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...) hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...) # For performance, specially handle the case where the matrices/vectors have homogeneous eltype -Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims) +Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...) vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...) hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...) hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index d650cf67ebf11..df2dbe1c198b9 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -733,6 +733,10 @@ function test_cat(::Type{TestAbstractArray}) cat3v(As) = cat(As...; dims=Val(3)) @test @inferred(cat3v(As)) == zeros(2, 2, 2) @test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4) + + r = rand(Float32, 56, 56, 64, 1); + f(r) = cat(r, r, dims=(3,)) + @inferred f(r); end function test_ind2sub(::Type{TestAbstractArray}) diff --git a/test/ambiguous.jl b/test/ambiguous.jl index e7b3b13fba0ff..8d8c3efab53b9 100644 --- a/test/ambiguous.jl +++ b/test/ambiguous.jl @@ -172,20 +172,11 @@ using LinearAlgebra, SparseArrays, SuiteSparse # not using isempty so this prints more information when it fails @testset "detect_ambiguities" begin let ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in detect_ambiguities(Core, Base; recursive=true, ambiguous_bottom=false, allowed_undefineds))) - @test isempty(ambig) - expect = [] good = true - while !isempty(ambig) - sigs = pop!(ambig) - i = findfirst(==(sigs), expect) - if i === nothing - println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))") - good = false - continue - end - deleteat!(expect, i) + for (sig1, sig2) in ambig + @test sig1 === sig2 # print this ambiguity + good = false end - @test isempty(expect) @test good end