Skip to content

Commit

Permalink
Generic tiled matrix multiply with a single algorithm (Lapack fallback)
Browse files Browse the repository at this point in the history
This is the fastest version yet.
  • Loading branch information
timholy committed Aug 13, 2012
1 parent 51c9557 commit f597671
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 78 deletions.
54 changes: 50 additions & 4 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function copy_to{T}(dest::Array{T}, dsto, src::Array{T}, so, N)
end
copy_to_unsafe(dest, dsto, src, so, N)
end
# @Jeff: is this split needed?

function copy_to_unsafe{T}(dest::Array{T}, dsto, src::Array{T}, so, N)
if isa(T, BitsKind)
ccall(:memcpy, Ptr{Void}, (Ptr{Void}, Ptr{Void}, Uint),
Expand All @@ -39,6 +39,52 @@ end

# Whole-array convenience method: copies numel(src) elements from src into
# dest, with both the source and the destination starting at linear index 1.
copy_to{T}(dest::Array{T}, src::Array{T}) = copy_to(dest, 1, src, 1, numel(src))

# Copy the rectangular block A[ir_src, jr_src] into B[ir_dest, jr_dest].
#
# The source and destination ranges must have matching lengths along each
# dimension, otherwise an error is raised. Both range pairs are bounds-checked
# before any element is written.
#
# When A has unit stride along its first dimension and the element types
# agree, each column is moved with a single linear copy_to call; otherwise the
# elements are assigned one at a time (converting to R on assignment).
function copy_to{R,S}(B::Matrix{R}, ir_dest::Range1{Int}, jr_dest::Range1{Int}, A::StridedMatrix{S}, ir_src::Range1{Int}, jr_src::Range1{Int})
    nrows = length(ir_src)
    if length(ir_dest) != nrows || length(jr_dest) != length(jr_src)
        error("copy_to: size mismatch")
    end
    check_bounds(B, ir_dest, jr_dest)
    check_bounds(A, ir_src, jr_src)
    # Distance, in linear-index units, between the starts of two adjacent columns.
    src_colstep = size(A, 1)
    dest_colstep = size(B, 1)
    dest_row0 = first(ir_dest)
    dest_col = first(jr_dest)
    if stride(A, 1) == 1 && R == S
        # Column-at-a-time copy.
        src_row0 = first(ir_src)
        for src_col in jr_src
            copy_to(B, (dest_col-1)*dest_colstep + dest_row0, A, (src_col-1)*src_colstep + src_row0, nrows)
            dest_col += 1
        end
    else
        # Element-by-element copy using linear indices into both arrays.
        for src_col in jr_src
            src_base = (src_col-1)*src_colstep
            dest_lin = (dest_col-1)*dest_colstep + dest_row0
            for src_row in ir_src
                B[dest_lin] = A[src_base+src_row]
                dest_lin += 1
            end
            dest_col += 1
        end
    end
end
# Copy the transpose of the block A[ir_src, jr_src] into B[ir_dest, jr_dest],
# i.e. rows of the source block become columns of the destination block.
#
# length(ir_dest) must equal length(jr_src) and length(jr_dest) must equal
# length(ir_src); otherwise an error is raised. Both range pairs are
# bounds-checked before any element is written. No conjugation is performed.
function copy_to_transpose{R,S}(B::Matrix{R}, ir_dest::Range1{Int}, jr_dest::Range1{Int}, A::StridedMatrix{S}, ir_src::Range1{Int}, jr_src::Range1{Int})
    if length(ir_dest) != length(jr_src) || length(jr_dest) != length(ir_src)
        # Name the function that actually failed (the message used to say "copy_to").
        error("copy_to_transpose: size mismatch")
    end
    check_bounds(B, ir_dest, jr_dest)
    check_bounds(A, ir_src, jr_src)
    idest = first(ir_dest)
    Askip = size(A, 1)   # linear-index distance between adjacent columns of A
    for jsrc in jr_src
        offset = (jsrc-1)*Askip
        # A source column is read contiguously and written along a row of B.
        jdest = first(jr_dest)
        for isrc in ir_src
            B[idest,jdest] = A[offset+isrc]
            jdest += 1
        end
        idest += 1
    end
end

function reinterpret{T,S}(::Type{T}, a::Array{S,1})
nel = int(div(numel(a)*sizeof(S),sizeof(T)))
ccall(:jl_reshape_array, Array{T,1}, (Any, Any, Any), Array{T,1}, a, (nel,))
Expand Down Expand Up @@ -211,12 +257,12 @@ end

# Vector case: forwards to the length-based check_bounds using length(A).
check_bounds(A::AbstractVector, I::Indices) = check_bounds(length(A), I)

# Two-index bounds check for a matrix: I is checked against size(A,1) and
# J against size(A,2).
function check_bounds(A::Matrix, I::Indices, J::Indices)
function check_bounds(A::AbstractMatrix, I::Indices, J::Indices)
check_bounds(size(A,1), I)
check_bounds(size(A,2), J)
end

function check_bounds(A::Array, I::Indices, J::Indices)
function check_bounds(A::AbstractArray, I::Indices, J::Indices)
check_bounds(size(A,1), I)
sz = size(A,2)
for i = 3:ndims(A)
Expand All @@ -225,7 +271,7 @@ function check_bounds(A::Array, I::Indices, J::Indices)
check_bounds(sz, J)
end

function check_bounds(A::Array, I::Indices...)
function check_bounds(A::AbstractArray, I::Indices...)
n = length(I)
if n > 0
for dim = 1:(n-1)
Expand Down
138 changes: 64 additions & 74 deletions base/linalg_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ cross(a::Vector, b::Vector) =
# cases are handled here

# Dimensions of op(M) for a transposition flag t: the flag 'N' gives
# (size(M,1), size(M,2)); any other flag gives the two sizes swapped.
function lapack_size(t::Char, M::StridedVecOrMat)
    nrow = size(M, 1)
    ncol = size(M, 2)
    if t == 'N'
        return (nrow, ncol)
    end
    return (ncol, nrow)
end
# Copy the block op(M)[ir_src, jr_src] into B[ir_dest, jr_dest], where op is
# selected by the flag tM: 'N' copies M as is, any other flag copies the
# transpose, and 'C' additionally conjugates the copied elements.
#
# ir_src/jr_src index into op(M), so for the transposed cases they are swapped
# before being handed to copy_to_transpose, which indexes M itself.
#
# Only the destination block B[ir_dest, jr_dest] is modified; elements of B
# outside that block are left untouched. Returns B.
function copy_to{R,S}(B::Matrix{R}, ir_dest::Range1{Int}, jr_dest::Range1{Int}, tM::Char, M::StridedMatrix{S}, ir_src::Range1{Int}, jr_src::Range1{Int})
    if tM == 'N'
        copy_to(B, ir_dest, jr_dest, M, ir_src, jr_src)
    else
        copy_to_transpose(B, ir_dest, jr_dest, M, jr_src, ir_src)
        if tM == 'C'
            # Conjugate just the freshly written block. Conjugating all of B
            # would also flip entries that this call never wrote.
            for j in jr_dest
                for i in ir_dest
                    B[i,j] = conj(B[i,j])
                end
            end
        end
    end
    return B
end
# Copy the transpose of the block op(M)[ir_src, jr_src] into
# B[ir_dest, jr_dest], where op is selected by the flag tM: for 'N' the block
# of M is transposed on the way in; for any other flag the two transposes
# cancel, so M is copied directly with the source ranges swapped. 'C'
# additionally conjugates the copied elements.
#
# Only the destination block B[ir_dest, jr_dest] is modified; elements of B
# outside that block are left untouched. Returns B.
function copy_to_transpose{R,S}(B::Matrix{R}, ir_dest::Range1{Int}, jr_dest::Range1{Int}, tM::Char, M::StridedMatrix{S}, ir_src::Range1{Int}, jr_src::Range1{Int})
    if tM == 'N'
        copy_to_transpose(B, ir_dest, jr_dest, M, ir_src, jr_src)
    else
        copy_to(B, ir_dest, jr_dest, M, jr_src, ir_src)
        if tM == 'C'
            # Conjugate just the freshly written block. Conjugating all of B
            # would also flip entries that this call never wrote.
            for j in jr_dest
                for i in ir_dest
                    B[i,j] = conj(B[i,j])
                end
            end
        end
    end
    return B
end


# TODO: It will be faster for large matrices to convert to float,
# call BLAS, and convert back to required type.
Expand Down Expand Up @@ -106,96 +127,65 @@ function _jl_generic_matmatmul{T,S}(tA, tB, A::StridedMatrix{T}, B::StridedMatri
C = Array(promote_type(T,S), mA, nB)
_jl_generic_matmatmul(C, tA, tB, A, B)
end

# Scratch space for the tiled generic matrix multiply: three raw byte buffers
# that are reinterpreted as square tiles of the result element type.
# Declared const so that code reading these globals sees a fixed binding and
# type rather than an untyped global.
const tilebufsize = 10800  # Approximately 32k/3
const Abuf = Array(Uint8, tilebufsize)
const Bbuf = Array(Uint8, tilebufsize)
const Cbuf = Array(Uint8, tilebufsize)

# Driver for the generic (non-BLAS) product C = op(A)*op(B), where tA and tB
# are transposition flags interpreted by lapack_size.
#
# Raises an error when the inner dimensions of op(A) and op(B) disagree or
# when C does not have the size of the product. 2x2 and 3x3 products are
# handed to matmul2x2/matmul3x3, whose result is returned directly (those
# calls do not receive C). All other sizes go through the tiled kernel.
#
# The tiles are views of the shared byte buffers Abuf/Bbuf/Cbuf reinterpreted
# as Matrix{R}.
# NOTE(review): this presumably requires R to be a plain bits type and the
# buffers to be suitably aligned for R - confirm before using other types.
function _jl_generic_matmatmul{T,S,R}(C::StridedMatrix{R}, tA, tB, A::StridedMatrix{T}, B::StridedMatrix{S})
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    # Validate the inner dimension first so that the small-matrix shortcuts
    # below cannot be reached with a mismatched B.
    if nA != mB; error("*: argument shapes do not match"); end
    if mA == 2 && nA == 2 && nB == 2; return matmul2x2(tA,tB,A,B); end
    if mA == 3 && nA == 3 && nB == 3; return matmul3x3(tA,tB,A,B); end
    if size(C,1) != mA || size(C,2) != nB; error("*: output size is incorrect"); end
    # int() converts to the native Int: on 32-bit builds ifloor yields an
    # Int64, and pointer_to_array has no method for (Int64,Int64) dims there.
    tile_size = int(ifloor(sqrt(tilebufsize/sizeof(R))))
    sz = (tile_size, tile_size)
    Atile = pointer_to_array(convert(Ptr{R}, pointer(Abuf)), sz)
    Btile = pointer_to_array(convert(Ptr{R}, pointer(Bbuf)), sz)
    Ctile = pointer_to_array(convert(Ptr{R}, pointer(Cbuf)), sz)

    # Call in a separate function for reasons of type inference (makes
    # a huge performance difference)
    _jl_generic_matmatmul(C, tA, tB, A, B, Atile, Btile, Ctile)
end

# Tiled worker for the generic matrix multiply. It receives the three tile
# buffers Atile, Btile and Ctile already typed as Matrix{R}, and takes the
# tile edge length from size(Atile, 1).
# NOTE(review): this region appears to interleave lines of the previous
# untiled implementation with the new tiled loop - confirm which statements
# are live before relying on the comments below.
function _jl_generic_matmatmul{T,S,R}(C::StridedMatrix{R}, tA, tB, A::StridedMatrix{T}, B::StridedMatrix{S}, Atile::Matrix{R}, Btile::Matrix{R}, Ctile::Matrix{R})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
tile_size = size(Atile, 1)
z = zero(R)
fill!(C, z)

Astride = size(A, 1)
Bstride = size(B, 1)
Cstride = size(C, 1)
tilesz = ifloor(sqrt(10800/sizeof(R))) # assumes L1 cache is >=32k
fA = lapack_flag(tA)
fB = lapack_flag(tB)

if tB == 'N'
if tA == 'N'
# Tiled multiplication
for jb = 1:tilesz:nB
jlim = min(jb+tilesz-1,nB)
for ib = 1:tilesz:mA
ilim = min(ib+tilesz-1,mA)
for kb = 1:tilesz:nA
klim = min(kb+tilesz-1,mB)
for j=jb:jlim
boffs = (j-1)*Bstride
coffs = (j-1)*Cstride
for i=ib:ilim
s = z
for k=kb:klim
s += A[i,k] * B[boffs+k]
end
C[coffs+i] += s
end
for jb = 1:tile_size:nB
jlim = min(jb+tile_size-1,nB)
jlen = jlim-jb+1
for ib = 1:tile_size:mA
ilim = min(ib+tile_size-1,mA)
ilen = ilim-ib+1
fill!(Ctile, z)
for kb = 1:tile_size:nA
klim = min(kb+tile_size-1,mB)
klen = klim-kb+1
# The A block is stored transposed (destination ranges 1:klen, 1:ilen), so
# the k index runs along the first dimension of both Atile and Btile.
copy_to_transpose(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
copy_to(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
for j=1:jlen
bcoff = (j-1)*tile_size
for i = 1:ilen
aoff = (i-1)*tile_size
s = z
for k = 1:klen
s += Atile[aoff+k] * Btile[bcoff+k]
end
Ctile[bcoff+i] += s
end
end
end
else
# tA = 'T'/'C', tB = 'N'.
# This is the fastest case. In contrast to the case above,
# I haven't yet seen evidence that tiling helps here. So
# this is a simple untiled algorithm.
for j = 1:nB
boffs = (j-1)*Bstride
coffs = (j-1)*Cstride
for i = 1:mA
aoffs = (i-1)*Astride
s = z
for k = 1:nA
s += value(fA, A[aoffs+k]) * B[boffs+k]
end
C[coffs+i] = s
end
end
end
else
if tA == 'N'
# tA = 'N', tB = 'T'/'C'
for k = 1:nA
aoffs = (k-1)*Astride
for j = 1:nB
coffs = (j-1)*Cstride
b = value(fB, B[j,k])
for i = 1:mA
C[coffs+i] += A[aoffs+i]*b
end
end
end
else
# tA = 'T'/'C', tB = 'T'/'C'
for j = 1:nB
coffs = (j-1)*Cstride
for i = 1:mA
aoffs = (i-1)*Astride
s = z
for k = 1:nA
s += value(fA, A[aoffs+k]) * value(fB, B[j,k])
end
C[coffs+i] = s
end
end
# Write the accumulated ilen-by-jlen corner of Ctile into C[ib:ilim, jb:jlim].
copy_to(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
end
end

return C
end


# multiply 2x2 matrices
function matmul2x2{T,S}(tA, tB, A::StridedMatrix{T}, B::StridedMatrix{S})
R = promote_type(T,S)
Expand Down

2 comments on commit f597671

@StefanKarpinski
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Throw in a little Polly and we might just have JLBLAS.

@dcampbell24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit broke multiplication on my 32-bit Ubuntu 12.04 machine. I checked this by checking out the previous commit.

julia> [1 2; 3 4]*[1 2; 3 4]
no method pointer_to_array(Ptr{Int32},(Int64,Int64),Bool)
in method_missing at base.jl:70
in _jl_generic_matmatmul at linalg_dense.jl:143
in _jl_generic_matmatmul at linalg_dense.jl:128
in * at linalg_dense.jl:123

Please sign in to comment.