Skip to content

Matmul: dispatch on specific blas paths using an enum #55002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 116 additions & 52 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,45 @@ true
"""
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
module BlasFlag
@enum BlasFunction SYRK HERK GEMM SYMM HEMM NONE
const SyrkHerkGemm = Union{Val{SYRK}, Val{HERK}, Val{GEMM}}
const SymmHemmGeneric = Union{Val{SYMM}, Val{HEMM}, Val{NONE}}
end
@inline function _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number)
tA = wrapper_char(A)
tB = wrapper_char(B)
tA_uc = uppercase(tA)
tB_uc = uppercase(tB)
isntc = wrapper_char_NTC(A) & wrapper_char_NTC(B)
blasfn = if isntc
if (tA_uc == 'T' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'T')
BlasFlag.SYRK
elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C')
BlasFlag.HERK
else isntc
BlasFlag.GEMM
end
else
if (tA_uc == 'S' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'S')
BlasFlag.SYMM
elseif (tA_uc == 'H' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'H')
BlasFlag.HEMM
else
BlasFlag.NONE
end
end

generic_matmatmul_wrapper!(
C,
wrapper_char(A),
wrapper_char(B),
tA,
tB,
_unwrap(A),
_unwrap(B),
α, β,
Val(wrapper_char_NTC(A) & wrapper_char_NTC(B))
Val(blasfn),
)
end

# this indirection allows is to specialize on the types of the wrappers of A and B to some extent,
# even though the wrappers are stripped off in mul!
Expand Down Expand Up @@ -415,7 +444,7 @@ end

# THE one big BLAS dispatch. This is split into two methods to improve latency
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{true}) where {T<:BlasFloat}
α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
Expand All @@ -425,24 +454,31 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
return _rmul_or_fill!(C, β)
end
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, α, β)
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, α, β)
_syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
return C
end
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK})
if A === B
tA_uc = uppercase(tA) # potentially strip a WrapperChar
return syrk_wrapper!(C, tA_uc, A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
end
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.HERK})
if A === B
tA_uc = uppercase(tA) # potentially strip a WrapperChar
return herk_wrapper!(C, tA_uc, A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
end
Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.GEMM})
return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
_valtypeparam(v::Val{T}) where {T} = T
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasFloat}
α::Number, β::Number, val::BlasFlag.SymmHemmGeneric) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
Expand All @@ -452,23 +488,48 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix
return _rmul_or_fill!(C, β)
end
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
if tA_uc == 'S' && tB_uc == 'N'
return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
elseif tA_uc == 'N' && tB_uc == 'S'
return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
elseif tA_uc == 'H' && tB_uc == 'N'
return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
elseif tA_uc == 'N' && tB_uc == 'H'
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
end
blasfn = _valtypeparam(val)
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && blasfn ∈ (BlasFlag.SYMM, BlasFlag.HEMM)
_blasfn = blasfn
αβ = (alpha, beta)
else
_blasfn = BlasFlag.NONE
αβ = (α, β)
end
return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_symm_hemm_generic!(C, tA, tB, A, B, αβ..., Val(_blasfn))
return C
end
Base.@constprop :aggressive function _lrchar_ulchar(tA, tB)
if uppercase(tA) == 'N'
lrchar = 'R'
ulchar = isuppercase(tB) ? 'U' : 'L'
else
lrchar = 'L'
ulchar = isuppercase(tA) ? 'U' : 'L'
end
return lrchar, ulchar
end
function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.SYMM})
lrchar, ulchar = _lrchar_ulchar(tA, tB)
if lrchar == 'L'
BLAS.symm!(lrchar, ulchar, alpha, A, B, beta, C)
else
BLAS.symm!(lrchar, ulchar, alpha, B, A, beta, C)
end
end
function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.HEMM})
lrchar, ulchar = _lrchar_ulchar(tA, tB)
if lrchar == 'L'
BLAS.hemm!(lrchar, ulchar, alpha, A, B, beta, C)
else
BLAS.hemm!(lrchar, ulchar, alpha, B, A, beta, C)
end
end
Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, beta, ::Val{BlasFlag.NONE})
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
Expand All @@ -479,8 +540,8 @@ function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::S
gemm_wrapper!(C, tA, tB, A, B, α, β)
end
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal}
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Expand Down Expand Up @@ -682,7 +743,7 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, true, false)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul())
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), true, false)
end
end

Expand All @@ -709,7 +770,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
_fullstride2(A) && _fullstride2(B) && _fullstride2(C))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end
# legacy method
gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -744,7 +805,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β)
end
# legacy method
gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -914,12 +975,16 @@ end
# aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, alpha::Number, beta::Number) =
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)

# legacy method
_generic_matmatmul!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) =
_generic_matmatmul!(C, A, B, _add.alpha, _add.beta)

@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
_add::MulAddMul{ais1}) where {T,S,R,ais1}
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number, beta::Number) where {R}
AxM = axes(A, 1)
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
BxK = axes(B, 1)
Expand All @@ -935,34 +1000,33 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
if BxN != CxN
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
end
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
if isbitstype(R) && sizeof(R) ≤ 16 && !(A isa Adjoint || A isa Transpose)
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
_rmul_or_fill!(C, beta)
(iszero(alpha) || isempty(A) || isempty(B)) && return C
@inbounds for n in BxN, k in BxK
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
Balpha = _rmul_alpha(B[k,n])
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
@simd for m in AxM
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
end
end
elseif isbitstype(R) && sizeof(R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
_rmul_or_fill!(C, _add.beta)
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
_rmul_or_fill!(C, beta)
(iszero(alpha) || isempty(A) || isempty(B)) && return C
t = wrapperop(A)
pB = parent(B)
pA = parent(A)
tmp = similar(C, CxN)
ci = first(CxM)
ta = t(_add.alpha)
ta = t(alpha)
for i in AxM
mul!(tmp, pB, view(pA, :, i))
@views C[ci,:] .+= t.(ta .* tmp)
ci += 1
end
else
if iszero(_add.alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, _add.beta)
if iszero(alpha) || isempty(A) || isempty(B)
return _rmul_or_fill!(C, beta)
end
a1 = first(AxK)
b1 = first(BxK)
Expand All @@ -972,7 +1036,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
@simd for k in AxK
Ctmp = muladd(A[i, k], B[k, j], Ctmp)
end
_modify!(_add, Ctmp, C, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), Ctmp, C, (i,j))
end
end
return C
Expand Down