Skip to content

Make matrix multiplication work for more types #18218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,4 +1058,22 @@ function reduced_dims0(dims::Dims, region)
map(last, reduced_dims0(map(n->OneTo(n), dims), region))
end

# #18218
eval(Base.LinAlg, quote
function arithtype(T)
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
"if you need its functionality, consider defining it locally."),
:arithtype)
T
end
function arithtype(::Type{Bool})
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
"if you need its functionality, consider defining it locally."),
:arithtype)
Int
end
end)

# End deprecations scheduled for 0.6
35 changes: 17 additions & 18 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

# matmul.jl: Everything to do with dense matrix multiplication

arithtype(T) = T
arithtype(::Type{Bool}) = Int
matprod(x, y) = x*y + x*y

# multiply by diagonal matrix as vector
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
Expand Down Expand Up @@ -76,11 +75,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(

# Matrix-vector multiplication
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
Expand All @@ -99,22 +98,22 @@ end
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)

function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)

function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Expand All @@ -132,7 +131,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
Matrix multiplication.
"""
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
Expand Down Expand Up @@ -166,14 +165,14 @@ julia> Y
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)

function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)

function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
Expand All @@ -190,7 +189,7 @@ end
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)

function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
Expand All @@ -199,7 +198,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
end
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Expand All @@ -208,14 +207,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
TS = promote_op(*, arithtype(T), arithtype(S))
TS = promote_op(matprod, T, S)
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!(similar(B, promote_op(matprod, T, S), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
Expand Down Expand Up @@ -448,7 +447,7 @@ end
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
C = similar(B, promote_op(matprod, T, S), mA, nB)
generic_matmatmul!(C, tA, tB, A, B)
end

Expand Down Expand Up @@ -642,7 +641,7 @@ end

# multiply 2x2 matrices
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B)
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
end

function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down Expand Up @@ -671,7 +670,7 @@ end

# Multiply 3x3 matrices
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B)
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
end

function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
Expand Down
27 changes: 27 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,30 @@ let
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
end
end

# #18218
module TestPR18218
using Base.Test
import Base.*, Base.+, Base.zero
immutable TypeA
x::Int
end
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
immutable TypeB
x::Int
end
immutable TypeC
x::Int
end
Base.convert(::Type{TypeC}, x::Int) = TypeC(x)
zero(c::TypeC) = TypeC(0)
zero(::Type{TypeC}) = TypeC(0)
(*)(x::Int, a::TypeA) = TypeB(x*a.x)
(*)(a::TypeA, x::Int) = TypeB(a.x*x)
(+)(a::Union{TypeB,TypeC}, b::Union{TypeB,TypeC}) = TypeC(a.x+b.x)
A = TypeA[1 2; 3 4]
b = [1, 2]
d = A * b
@test typeof(d) == Vector{TypeC}
@test d == TypeC[5, 11]
end