Skip to content

attempt to precompile linsolve #698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 12, 2021
2 changes: 2 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ include("init.jl")
include("forwarddiff.jl")
include("chainrules.jl")

include("precompile.jl")

"""
$(TYPEDEF)
"""
Expand Down
28 changes: 25 additions & 3 deletions src/linear_nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ DefaultLinSolve() = DefaultLinSolve(nothing, nothing, nothing)
end

function isopenblas()
@static if VERSION < v"1.7"
@static if VERSION < v"1.7beta"
blas = BLAS.vendor()
blas == :openblas64 || blas == :openblas
else
Expand Down Expand Up @@ -131,7 +131,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;reltol=nothing, kwargs..
end

if A isa Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian,ForwardSensitivityJacobian} # No 2-arg form for SparseArrays!
x .= b
copyto!(x,b)
ldiv!(p.A,x)
# Missing a little bit of efficiency in a rare case
#elseif A isa DiffEqArrayOperator
Expand All @@ -144,7 +144,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;reltol=nothing, kwargs..
reltol = checkreltol(reltol)
p.iterable = IterativeSolvers.gmres_iterable!(x,A,b;initially_zero=true,restart=5,maxiter=5,abstol=1e-16,reltol=reltol,kwargs...)
end
x .= false
fill!(x,false)
iter = p.iterable
purge_history!(iter, x, b)

Expand All @@ -168,6 +168,28 @@ end
Base.resize!(p::DefaultLinSolve,i) = p.A = nothing
const DEFAULT_LINSOLVE = DefaultLinSolve()

## A much simpler LU for when we know it's Array

mutable struct LUFactorize
A::LU{Float64,Matrix{Float64}}
openblas::Bool
end
LUFactorize() = LUFactorize(lu(rand(1,1)),isopenblas())
function (p::LUFactorize)(x::Vector{Float64},A::Matrix{Float64},b::Vector{Float64},update_matrix::Bool=false;kwargs...)
if update_matrix
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 || (p.openblas && size(A,1) <= 500))
p.A = RecursiveFactorization.lu!(A)
else
p.A = lu!(A)
end
end
ldiv!(x,p.A,b)
end
function (p::LUFactorize)(::Type{Val{:init}},f,u0_prototype)
LUFactorize(lu(rand(eltype(u0_prototype),1,1)),p.openblas)
end
Base.resize!(p::LUFactorize,i) = p.A = nothing

### Default GMRES

# Easily change to GMRES
Expand Down
19 changes: 19 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
let
while true
_testf(du,u,p,t) = copyto!(du,u)
b = rand(1); x = rand(1)
_linsolve = DEFAULT_LINSOLVE(Val{:init},ODEFunction(_testf),b)
A = rand(1,1)
_linsolve(x,A,b,true)
_linsolve(x,A,b,false)
_linsolve = LUFactorize()(Val{:init},ODEFunction(_testf),b)
_linsolve(x,A,b,true)
_linsolve(x,A,b,false)
Pl = ScaleVector([1.0],true)
Pr = ScaleVector([1.0],false)
reltol = 1.0
_linsolve(x,A,b,true;reltol=reltol,Pl=Pl,Pr=Pr)
_linsolve(x,A,b,false;reltol=reltol,Pl=Pl,Pr=Pr)
break
end
end
11 changes: 7 additions & 4 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ function solve_up(prob::DEProblem,sensealg,u0,p,args...;kwargs...)

if haskey(kwargs,:alg) && (isempty(args) || args[1] === nothing)
alg = kwargs[:alg]
_prob = get_concrete_problem(prob,isadaptive(alg);u0=u0,p=p,kwargs...)
solve_call(_prob,alg,args...;kwargs...)
_alg = prepare_alg(alg,u0,p,prob)
_prob = get_concrete_problem(prob,isadaptive(_alg);u0=u0,p=p,kwargs...)
solve_call(_prob,_alg,args...;kwargs...)
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
alg = args[1]
_prob = get_concrete_problem(prob,isadaptive(alg);u0=u0,p=p,kwargs...)
solve_call(_prob,args...;kwargs...)
_alg = prepare_alg(alg,u0,p,prob)
_prob = get_concrete_problem(prob,isadaptive(_alg);u0=u0,p=p,kwargs...)
solve_call(_prob,_alg,Base.tail(args)...;kwargs...)
elseif isempty(args) # Default algorithm handling
_prob = get_concrete_problem(prob,!(typeof(prob)<:DiscreteProblem);u0=u0,p=p,kwargs...)
solve_call(_prob,args...;kwargs...)
Expand Down Expand Up @@ -203,6 +205,7 @@ function promote_f(f,u0)
end

promote_f(f::SplitFunction,u0) = typeof(f.cache) === typeof(u0) && isinplace(f) ? f : remake(f,cache=zero(u0))
prepare_alg(alg,u0,p,f) = alg

function get_concrete_tspan(prob, isadapt, kwargs, p)
if prob.tspan isa Function
Expand Down