Skip to content

Commit 7b6259e

Browse files
jarlebringdkarrasch
authored andcommitted
Efficiency improvement of exp(::StridedMatrix) with UniformScaling and mul! (JuliaLang#40668)
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent f1cb26f commit 7b6259e

File tree

1 file changed

+31
-11
lines changed

1 file changed

+31
-11
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,6 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
617617
end
618618
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
619619
nA = opnorm(A, 1)
620-
Inn = Matrix{T}(I, n, n)
621620
## For sufficiently small nA, use lower order Padé-Approximations
622621
if (nA <= 2.1)
623622
if nA > 0.95
@@ -634,17 +633,21 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
634633
C = T[120.,60.,12.,1.]
635634
end
636635
A2 = A * A
637-
P = copy(Inn)
638-
U = C[2] * P
639-
V = C[1] * P
640-
for k in 1:(div(size(C, 1), 2) - 1)
636+
# Compute U and V: Even/odd terms in Padé numerator & denom
637+
# Expansion of k=1 in for loop
638+
P = A2
639+
U = C[2]*I + C[4]*P
640+
V = C[1]*I + C[3]*P
641+
for k in 2:(div(size(C, 1), 2) - 1)
641642
k2 = 2 * k
642643
P *= A2
643-
U += C[k2 + 2] * P
644-
V += C[k2 + 1] * P
644+
mul!(U, C[k2 + 2], P, true, true) # U += C[k2+2]*P
645+
mul!(V, C[k2 + 1], P, true, true) # V += C[k2+1]*P
645646
end
647+
646648
U = A * U
647649
X = V + U
650+
# Padé approximant: (V-U)\(V+U)
648651
LAPACK.gesv!(V-U, X)
649652
else
650653
s = log2(nA/5.4) # power of 2 later reversed by squaring
@@ -660,10 +663,27 @@ function exp!(A::StridedMatrix{T}) where T<:BlasFloat
660663
A2 = A * A
661664
A4 = A2 * A2
662665
A6 = A2 * A4
663-
U = A * (A6 * (CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2) .+
664-
CC[8].*A6 .+ CC[6].*A4 .+ CC[4].*A2 .+ CC[2].*Inn)
665-
V = A6 * (CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2) .+
666-
CC[7].*A6 .+ CC[5].*A4 .+ CC[3].*A2 .+ CC[1].*Inn
666+
Ut = CC[4]*A2
667+
Ut[diagind(Ut)] .+= CC[2]
668+
# Allocation economical version of:
669+
#U = A * (A6 * (CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2) .+
670+
# CC[8].*A6 .+ CC[6].*A4 .+ Ut)
671+
U = mul!(CC[8].*A6 .+ CC[6].*A4 .+ Ut,
672+
A6,
673+
CC[14].*A6 .+ CC[12].*A4 .+ CC[10].*A2,
674+
true, true)
675+
U = A*U
676+
677+
# Allocation economical version of: Vt = CC[3]*A2 (recycle Ut)
678+
Vt = mul!(Ut, CC[3], A2, true, false)
679+
Vt[diagind(Vt)] .+= CC[1]
680+
# Allocation economical version of:
681+
#V = A6 * (CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2) .+
682+
# CC[7].*A6 .+ CC[5].*A4 .+ Vt
683+
V = mul!(CC[7].*A6 .+ CC[5].*A4 .+ Vt,
684+
A6,
685+
CC[13].*A6 .+ CC[11].*A4 .+ CC[9].*A2,
686+
true, true)
667687

668688
X = V + U
669689
LAPACK.gesv!(V-U, X)

0 commit comments

Comments
 (0)