Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions base/export.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,20 @@ export
SubOrDArray,
SubString,
TransformedString,
Tridiagonal,
VecOrMat,
Vector,
VersionNumber,
WeakKeyDict,
Woodbury,
Zip,
Stat,
Factorization,
Cholesky,
LU,
LUTridiagonal,
LDLT,
LDLTTridiagonal,
QR,
QRP,

Expand Down Expand Up @@ -634,6 +639,7 @@ export
randsym,
rank,
rref,
solve,
svd,
svdvals,
trace,
Expand Down
45 changes: 45 additions & 0 deletions base/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,48 @@ end
##ToDo: Add methods for rank(A::QRP{T}) and adjust the (\) method accordingly
## Add rcond methods for Cholesky, LU, QR and QRP types
## Lower priority: Add LQ, QL and RQ factorizations


#### Factorizations for Tridiagonal ####
type LDLTTridiagonal{T} <: Factorization{T}
D::Vector{T}
E::Vector{T}
end
function LDLTTridiagonal{T<:LapackScalar}(A::Tridiagonal{T})
D = copy(A.d)
E = copy(A.dl)
_jl_lapack_pttrf(D, E)
LDLTTridiagonal(D, E)
end
LDLT(A::Tridiagonal) = LDLTTridiagonal(A)

(\){T<:LapackScalar}(C::LDLTTridiagonal{T}, B::StridedVecOrMat{T}) =
_jl_lapack_pttrs(C.D, C.E, copy(B))

type LUTridiagonal{T} <: Factorization{T}
lu::Tridiagonal{T}
ipiv::Vector{Int32}
function LUTridiagonal(lu::Tridiagonal{T}, ipiv::Vector{Int32})
m, n = size(lu)
m == numel(ipiv) ? new(lu, ipiv) : error("LU: dimension mismatch")
end
end
show(io, lu::LUTridiagonal) = print(io, "LU decomposition of ", summary(lu.lu))

function LU{T<:LapackScalar}(A::Tridiagonal{T})
lu, ipiv = _jl_lapack_gttrf(copy(A))
LUTridiagonal{T}(lu, ipiv)
end

function lu(A::Tridiagonal)
error("lu(A) is not defined when A is Tridiagonal. Use LU(A) instead.")
end

function det(lu::LUTridiagonal)
prod(lu.lu.d) * (bool(sum(lu.ipiv .!= 1:n) % 2) ? -1 : 1)
end

det(A::Tridiagonal) = det(LU(A))

(\){T<:LapackScalar}(lu::LUTridiagonal{T}, B::StridedVecOrMat{T}) =
_jl_lapack_gttrs('N', lu.lu, lu.ipiv, copy(B))
4 changes: 3 additions & 1 deletion base/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
## linalg.jl: Basic Linear Algebra interface specifications ##
## linalg.jl: Basic Linear Algebra interface specifications and
## specialized matrix types

#
# This file mostly contains commented functions which are supposed
# to be defined in type-specific linalg_<type>.jl files.
Expand Down
184 changes: 184 additions & 0 deletions base/linalg_lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,187 @@ end
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))


#### Tridiagonal matrix routines ####
function \{T<:LapackScalar}(M::Tridiagonal{T}, rhs::StridedVecOrMat{T})
if stride(rhs, 1) == 1
x = copy(rhs)
Mc = copy(M)
Mlu, x = _jl_lapack_gtsv(Mc, x)
return x
end
solve(M, rhs) # use the Julia "fallback"
end

eig(M::Tridiagonal) = _jl_lapack_stev('V', copy(M))

# Decompositions
for (gttrf, pttrf, elty) in
((:dgttrf_,:dpttrf_,:Float64),
(:sgttrf_,:spttrf_,:Float32),
(:zgttrf_,:zpttrf_,:Complex128),
(:cgttrf_,:cpttrf_,:Complex64))
@eval begin
function _jl_lapack_gttrf(M::Tridiagonal{$elty})
info = zero(Int32)
n = int32(length(M.d))
ipiv = Array(Int32, n)
ccall(dlsym(_jl_liblapack, $string(gttrf)),
Void,
(Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, M.dl, M.d, M.du, M.dutmp, ipiv, &info)
if info != 0 throw(LapackException(info)) end
M, ipiv
end
function _jl_lapack_pttrf(D::Vector{$elty}, E::Vector{$elty})
info = zero(Int32)
n = int32(length(D))
if length(E) != n-1
error("subdiagonal must be one element shorter than diagonal")
end
ccall(dlsym(_jl_liblapack, $string(pttrf)),
Void,
(Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{Int32}),
&n, D, E, &info)
if info != 0 throw(LapackException(info)) end
D, E
end
end
end
# Direct solvers
for (gtsv, ptsv, elty) in
((:dgtsv_,:dptsv_,:Float64),
(:sgtsv_,:sptsv,:Float32),
(:zgtsv_,:zptsv,:Complex128),
(:cgtsv_,:cptsv,:Complex64))
@eval begin
function _jl_lapack_gtsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_gtsv: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(gtsv)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, M.dl, M.d, M.du, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
M, B
end
function _jl_lapack_ptsv(M::Tridiagonal{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_ptsv: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(ptsv)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, M.d, M.dl, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
M, B
end
end
end
# Solvers using decompositions
for (gttrs, pttrs, elty) in
((:dgttrs_,:dpttrs_,:Float64),
(:sgttrs_,:spttrs,:Float32),
(:zgttrs_,:zpttrs,:Complex128),
(:cgttrs_,:cpttrs,:Complex64))
@eval begin
function _jl_lapack_gttrs(trans::LapackChar, M::Tridiagonal{$elty}, ipiv::Vector{Int32}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_gttrs: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(M.d))
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(gttrs)),
Void,
(Ptr{Uint8}, Ptr{Int32}, Ptr{Int32},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{$elty}, Ptr{Int32}, Ptr{Int32}),
&trans, &n, &nrhs, M.dl, M.d, M.du, M.dutmp, ipiv, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
B
end
function _jl_lapack_pttrs(D::Vector{$elty}, E::Vector{$elty}, B::StridedVecOrMat{$elty})
if stride(B,1) != 1
error("_jl_lapack_pttrs: matrix columns must have contiguous elements");
end
info = zero(Int32)
n = int32(length(D))
if length(E) != n-1
error("subdiagonal must be one element shorter than diagonal")
end
nrhs = int32(size(B, 2))
ldb = int32(stride(B, 2))
ccall(dlsym(_jl_liblapack, $string(pttrs)),
Void,
(Ptr{Int32}, Ptr{Int32}, Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{Int32}),
&n, &nrhs, D, E, B, &ldb, &info)
if info != 0 throw(LapackException(info)) end
B
end
end
end
# Eigenvalue-eigenvector (symmetric only)
for (stev, elty) in
((:dstev_,:Float64),
(:sstev_,:Float32),
(:zstev_,:Complex128),
(:cstev_,:Complex64))
@eval begin
function _jl_lapack_stev(Z::Array, M::Tridiagonal{$elty})
n = int32(length(M.d))
if isempty(Z)
job = 'N'
ldz = 1
work = Array($elty, 0)
Ztmp = work
else
if stride(Z,1) != 1
error("_jl_lapack_stev: eigenvector matrix columns must have contiguous elements");
end
if size(Z, 1) != n
error("_jl_lapack_stev: eigenvector matrix columns are not of the correct size")
end
Ztmp = Z
job = 'V'
ldz = int32(stride(Z, 2))
work = Array($elty, max(1, 2*n-2))
end
info = zero(Int32)
ccall(dlsym(_jl_liblapack, $string(stev)),
Void,
(Ptr{Uint8}, Ptr{Int32},
Ptr{$elty}, Ptr{$elty}, Ptr{$elty},
Ptr{Int32}, Ptr{$elty}, Ptr{Int32}),
&job, &n, M.d, M.dl, Ztmp, &ldz, work, &info)
if info != 0 throw(LapackException(info)) end
M.d
end
end
end
function _jl_lapack_stev(job::LapackChar, M::Tridiagonal)
if job == 'N' || job == 'n'
Z = []
elseif job == 'V' || job == 'v'
n = length(M.d)
Z = Array(eltype(M), n, n)
else
error("Job type not recognized")
end
D = _jl_lapack_stev(Z, M)
return D, Z
end
Loading