Skip to content

Commit d77a389

Browse files
Merge pull request #86 from SciML/linearsolve
use LinearSolve.jl
2 parents a88e748 + fc05995 commit d77a389

File tree

4 files changed

+70
-100
lines changed

4 files changed

+70
-100
lines changed

Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ version = "0.3.22"
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10-
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
13-
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
1412
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1513
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1614
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -21,9 +19,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2119
ArrayInterfaceCore = "0.1.1"
2220
FiniteDiff = "2"
2321
ForwardDiff = "0.10.3"
24-
IterativeSolvers = "0.9"
2522
RecursiveArrayTools = "2"
26-
RecursiveFactorization = "0.1, 0.2"
2723
Reexport = "0.2, 1"
2824
SciMLBase = "1.32"
2925
Setfield = "0.7, 0.8, 1"

src/NonlinearSolve.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@ using StaticArrays
99
using RecursiveArrayTools
1010
using LinearAlgebra
1111
import ArrayInterfaceCore
12-
import IterativeSolvers
13-
import RecursiveFactorization
1412

1513
@reexport using SciMLBase
1614

1715
abstract type AbstractNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
1816
abstract type AbstractBracketingAlgorithm <: AbstractNonlinearSolveAlgorithm end
19-
abstract type AbstractNewtonAlgorithm{CS, AD} <: AbstractNonlinearSolveAlgorithm end
17+
abstract type AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ} <:
18+
AbstractNonlinearSolveAlgorithm end
2019
abstract type AbstractImmutableNonlinearSolver <: AbstractNonlinearSolveAlgorithm end
2120

2221
include("utils.jl")

src/raphson.jl

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
struct NewtonRaphson{CS, AD, DT, L} <: AbstractNewtonAlgorithm{CS, AD}
2-
diff_type::DT
1+
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <:
2+
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
33
linsolve::L
4+
precs::P
45
end
56

6-
function NewtonRaphson(; autodiff = true, chunk_size = 12, diff_type = Val{:forward},
7-
linsolve = DEFAULT_LINSOLVE)
8-
NewtonRaphson{chunk_size, autodiff, typeof(diff_type), typeof(linsolve)}(diff_type,
9-
linsolve)
7+
function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
8+
standardtag = Val{true}(), concrete_jac = nothing,
9+
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
10+
NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
11+
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
12+
_unwrap_val(concrete_jac)}(linsolve, precs)
1013
end
1114

1215
mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
@@ -17,10 +20,64 @@ mutable struct NewtonRaphsonCache{ufType, L, jType, uType, JC}
1720
jac_config::JC
1821
end
1922

23+
function dolinsolve(precs::P, linsolve; A = nothing, linu = nothing, b = nothing,
24+
du = nothing, u = nothing, p = nothing, t = nothing,
25+
weight = nothing, solverdata = nothing,
26+
reltol = nothing) where {P}
27+
A !== nothing && (linsolve = LinearSolve.set_A(linsolve, A))
28+
b !== nothing && (linsolve = LinearSolve.set_b(linsolve, b))
29+
linu !== nothing && (linsolve = LinearSolve.set_u(linsolve, linu))
30+
31+
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
32+
linsolve.Pl
33+
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
34+
linsolve.Pr
35+
36+
_Pl, _Pr = precs(linsolve.A, du, u, p, nothing, A !== nothing, Plprev, Prprev,
37+
solverdata)
38+
if (_Pl !== nothing || _Pr !== nothing)
39+
_weight = weight === nothing ?
40+
(linsolve.Pr isa Diagonal ? linsolve.Pr.diag : linsolve.Pr.inner.diag) :
41+
weight
42+
Pl, Pr = wrapprecs(_Pl, _Pr, _weight)
43+
linsolve = LinearSolve.set_prec(linsolve, Pl, Pr)
44+
end
45+
46+
linres = if reltol === nothing
47+
solve(linsolve; reltol)
48+
else
49+
solve(linsolve; reltol)
50+
end
51+
52+
return linres
53+
end
54+
55+
function wrapprecs(_Pl, _Pr, weight)
56+
if _Pl !== nothing
57+
Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
58+
_Pl)
59+
else
60+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
61+
end
62+
63+
if _Pr !== nothing
64+
Pr = LinearSolve.ComposePreconditioner(Diagonal(_vec(weight)), _Pr)
65+
else
66+
Pr = Diagonal(_vec(weight))
67+
end
68+
Pl, Pr
69+
end
70+
2071
function alg_cache(alg::NewtonRaphson, f, u, p, ::Val{true})
2172
uf = JacobianWrapper(f, p)
22-
linsolve = alg.linsolve(Val{:init}, f, u)
2373
J = false .* u .* u'
74+
75+
linprob = LinearProblem(W, _vec(zero(u)); u0 = _vec(zero(u)))
76+
Pl, Pr = wrapprecs(alg.precs(W, nothing, u, p, nothing, nothing, nothing, nothing,
77+
nothing)..., weight)
78+
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
79+
Pl = Pl, Pr = Pr)
80+
2481
du1 = zero(u)
2582
tmp = zero(u)
2683
if alg_autodiff(alg)
@@ -47,9 +104,12 @@ function perform_step(solver::NewtonImmutableSolver, alg::NewtonRaphson, ::Val{t
47104
@unpack J, linsolve, du1 = cache
48105
calc_J!(J, solver, cache)
49106
# u = u - J \ fu
50-
linsolve(du1, J, fu, true)
107+
linsolve = dolinsolve(alg.precs, solver.linsolve, A = J, b = fu, u = du1,
108+
p = p, reltol = solver.tol)
109+
@set! cache.linsolve = linsolve
51110
@. u = u - du1
52111
f(fu, u, p)
112+
53113
if solver.internalnorm(solver.fu) < solver.tol
54114
@set! solver.force_stop = true
55115
end

src/utils.jl

Lines changed: 0 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -100,91 +100,6 @@ function num_types_in_tuple(sig::UnionAll)
100100
length(Base.unwrap_unionall(sig).parameters)
101101
end
102102

103-
### Default Linsolve
104-
105-
# Try to be as smart as possible
106-
# lu! if Matrix
107-
# lu if sparse
108-
# gmres if operator
109-
110-
mutable struct DefaultLinSolve
111-
A::Any
112-
iterable::Any
113-
end
114-
DefaultLinSolve() = DefaultLinSolve(nothing, nothing)
115-
116-
function (p::DefaultLinSolve)(x, A, b, update_matrix = false; tol = nothing, kwargs...)
117-
if p.iterable isa Vector && eltype(p.iterable) <: LinearAlgebra.BlasInt # `iterable` here is the pivoting vector
118-
F = LU{eltype(A)}(A, p.iterable, zero(LinearAlgebra.BlasInt))
119-
ldiv!(x, F, b)
120-
return nothing
121-
end
122-
if update_matrix
123-
if typeof(A) <: Matrix
124-
blasvendor = BLAS.vendor()
125-
# if the user doesn't use OpenBLAS, we assume that is a better BLAS
126-
# implementation like MKL
127-
#
128-
# RecursiveFactorization seems to be consistantly winning below 100
129-
# https://discourse.julialang.org/t/ann-recursivefactorization-jl/39213
130-
if ArrayInterfaceCore.can_setindex(x) && (size(A, 1) <= 100 ||
131-
((blasvendor === :openblas || blasvendor === :openblas64) &&
132-
size(A, 1) <= 500))
133-
p.A = RecursiveFactorization.lu!(A)
134-
else
135-
p.A = lu!(A)
136-
end
137-
elseif typeof(A) <: Tridiagonal
138-
p.A = lu!(A)
139-
elseif typeof(A) <: Union{SymTridiagonal}
140-
p.A = ldlt!(A)
141-
elseif typeof(A) <: Union{Symmetric, Hermitian}
142-
p.A = bunchkaufman!(A)
143-
elseif typeof(A) <: SparseMatrixCSC
144-
p.A = lu(A)
145-
elseif ArrayInterfaceCore.isstructured(A)
146-
p.A = factorize(A)
147-
elseif !(typeof(A) <: AbstractDiffEqOperator)
148-
# Most likely QR is the one that is overloaded
149-
# Works on things like CuArrays
150-
p.A = qr(A)
151-
end
152-
end
153-
154-
if typeof(A) <: Union{Matrix, SymTridiagonal, Tridiagonal, Symmetric, Hermitian} # No 2-arg form for SparseArrays!
155-
x .= b
156-
ldiv!(p.A, x)
157-
# Missing a little bit of efficiency in a rare case
158-
#elseif typeof(A) <: DiffEqArrayOperator
159-
# ldiv!(x,p.A,b)
160-
elseif ArrayInterfaceCore.isstructured(A) || A isa SparseMatrixCSC
161-
ldiv!(x, p.A, b)
162-
elseif typeof(A) <: AbstractDiffEqOperator
163-
# No good starting guess, so guess zero
164-
if p.iterable === nothing
165-
p.iterable = IterativeSolvers.gmres_iterable!(x, A, b; initially_zero = true,
166-
restart = 5, maxiter = 5,
167-
tol = 1e-16, kwargs...)
168-
p.iterable.reltol = tol
169-
end
170-
x .= false
171-
iter = p.iterable
172-
purge_history!(iter, x, b)
173-
174-
for residual in iter
175-
end
176-
else
177-
ldiv!(x, p.A, b)
178-
end
179-
return nothing
180-
end
181-
182-
function (p::DefaultLinSolve)(::Type{Val{:init}}, f, u0_prototype)
183-
DefaultLinSolve()
184-
end
185-
186-
const DEFAULT_LINSOLVE = DefaultLinSolve()
187-
188103
@inline UNITLESS_ABS2(x) = real(abs2(x))
189104
@inline DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
190105
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}

0 commit comments

Comments
 (0)