From 65b9be408674bd6b08ea302571b1ac0f52b7a9f7 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Mon, 25 Apr 2022 11:50:53 +0900 Subject: [PATCH] improve `cat` inferrability (#45028) Make `cat` inferrable even if its arguments are not fully constant: ```julia julia> r = rand(Float32, 56, 56, 64, 1); julia> f(r) = cat(r, r, dims=(3,)) f (generic function with 1 method) julia> @inferred f(r); julia> last(@code_typed f(r)) Array{Float32, 4} ``` After descending into its call graph, I found that constant propagation is prohibited at `cat_t(::Type{T}, X...; dims)` due to the method instance heuristic, i.e. its body is considered to be too complex for successful inlining although it's explicitly annotated as `@inline`. But for this case, the constant propagation is greatly helpful both for abstract interpretation and optimization since it can improve the return type inference. Since it is not an easy task to improve the method instance heuristic, which is our primary logic for constant propagation, this commit does a quick fix by helping inference with the `@constprop` annotation. There is another issue that currently there is no good way to properly apply `@constprop`/`@inline` effects to a keyword function (as a note, this is a general issue of macro annotations on a method definition). So this commit also changes some internal helper functions of `cat` so that now they are not keyword ones: the changes are also necessary for the `@inline` annotation on `cat_t` to be effective to trick the method instance heuristic. --- base/abstractarray.jl | 33 +++++++++++++---------------- stdlib/LinearAlgebra/src/special.jl | 4 ++-- test/abstractarray.jl | 4 ++++ test/ambiguous.jl | 15 +++---------- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index af29aee6a74b6..239e75df52510 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1716,13 +1716,7 @@ end _cs(d, a, b) = (a == b ? a : throw(DimensionMismatch( "mismatch in dimension $d (expected $a got $b)"))) -function dims2cat(::Val{dims}) where dims - if any(≤(0), dims) - throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) - end - ntuple(in(dims), maximum(dims)) -end - +dims2cat(::Val{dims}) where dims = dims2cat(dims) function dims2cat(dims) if any(≤(0), dims) throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) @@ -1730,9 +1724,8 @@ function dims2cat(dims) ntuple(in(dims), maximum(dims)) end -_cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) +_cat(dims, X...) = _cat_t(dims, promote_eltypeof(X...), X...) -@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) @inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) shape = cat_size_shape(catdims, X...) @@ -1742,6 +1735,9 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) end return __cat(A, shape, catdims, X...) end +# this version of `cat_t` is not very kind for inference and so its usage should be avoided, +# nevertheless it is here just for compat after https://github.com/JuliaLang/julia/pull/45028 +@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) # Why isn't this called `__cat!`? __cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) @@ -1880,8 +1876,8 @@ julia> reduce(hcat, vs) """ hcat(X...) = cat(X...; dims=Val(2)) -typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1)) -typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2)) +typed_vcat(::Type{T}, X...) where T = _cat_t(Val(1), T, X...) +typed_hcat(::Type{T}, X...) where T = _cat_t(Val(2), T, X...) """ cat(A...; dims) @@ -1917,7 +1913,8 @@ julia> cat(true, trues(2,2), trues(4)', dims=(1,2)) ``` """ @inline cat(A...; dims) = _cat(dims, A...) -_cat(catdims, A::AbstractArray{T}...) where {T} = cat_t(T, A...; dims=catdims) +# `@constprop :aggressive` allows `catdims` to be propagated as constant improving return type inference +@constprop :aggressive _cat(catdims, A::AbstractArray{T}...) where {T} = _cat_t(catdims, T, A...) # The specializations for 1 and 2 inputs are important # especially when running with --inline=no, see #11158 @@ -1928,12 +1925,12 @@ hcat(A::AbstractArray) = cat(A; dims=Val(2)) hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2)) hcat(A::AbstractArray...) = cat(A...; dims=Val(2)) -typed_vcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(1)) -typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(1)) -typed_vcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(1)) -typed_hcat(T::Type, A::AbstractArray) = cat_t(T, A; dims=Val(2)) -typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = cat_t(T, A, B; dims=Val(2)) -typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2)) +typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A) +typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B) +typed_vcat(T::Type, A::AbstractArray...) = _cat_t(Val(1), T, A...) +typed_hcat(T::Type, A::AbstractArray) = _cat_t(Val(2), T, A) +typed_hcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(2), T, A, B) +typed_hcat(T::Type, A::AbstractArray...) = _cat_t(Val(2), T, A...) # 2d horizontal and vertical concatenation diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 8d4292c6045ed..098df785e557a 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -414,14 +414,14 @@ const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpo promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix -Base._cat(dims, xs::_DenseConcatGroup...) = Base.cat_t(promote_eltype(xs...), xs...; dims=dims) +Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...) vcat(A::Vector...) = Base.typed_vcat(promote_eltype(A...), A...) vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...) hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...) hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...) hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...) # For performance, specially handle the case where the matrices/vectors have homogeneous eltype -Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims) +Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...) vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...) hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...) hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index d650cf67ebf11..df2dbe1c198b9 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -733,6 +733,10 @@ function test_cat(::Type{TestAbstractArray}) cat3v(As) = cat(As...; dims=Val(3)) @test @inferred(cat3v(As)) == zeros(2, 2, 2) @test @inferred(cat(As...; dims=Val((1,2)))) == zeros(4, 4) + + r = rand(Float32, 56, 56, 64, 1); + f(r) = cat(r, r, dims=(3,)) + @inferred f(r); end function test_ind2sub(::Type{TestAbstractArray}) diff --git a/test/ambiguous.jl b/test/ambiguous.jl index e7b3b13fba0ff..8d8c3efab53b9 100644 --- a/test/ambiguous.jl +++ b/test/ambiguous.jl @@ -172,20 +172,11 @@ using LinearAlgebra, SparseArrays, SuiteSparse # not using isempty so this prints more information when it fails @testset "detect_ambiguities" begin let ambig = Set{Any}(((m1.sig, m2.sig) for (m1, m2) in detect_ambiguities(Core, Base; recursive=true, ambiguous_bottom=false, allowed_undefineds))) - @test isempty(ambig) - expect = [] good = true - while !isempty(ambig) - sigs = pop!(ambig) - i = findfirst(==(sigs), expect) - if i === nothing - println(stderr, "push!(expect, (", sigs[1], ", ", sigs[2], "))") - good = false - continue - end - deleteat!(expect, i) + for (sig1, sig2) in ambig + @test sig1 === sig2 # print this ambiguity + good = false end - @test isempty(expect) @test good end