Skip to content

Support LinearProblems in SCCs #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 57 additions & 24 deletions lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,71 @@ module SCCNonlinearSolve
import SciMLBase
import CommonSolve
import SymbolicIndexingInterface
import SciMLBase: NonlinearProblem, NonlinearLeastSquaresProblem, LinearProblem

"""
SCCAlg(; nlalg = nothing, linalg = nothing)

Algorithm for solving Strongly Connected Component (SCC) problems containing
both nonlinear and linear subproblems.

### Keyword Arguments

- `nlalg`: Algorithm to use for solving NonlinearProblem components
- `linalg`: Algorithm to use for solving LinearProblem components
"""
struct SCCAlg{N, L}
nlalg::N
linalg::L
end

SCCAlg(; nlalg = nothing, linalg = nothing) = SCCAlg(nlalg, linalg)

function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...)
CommonSolve.solve(prob, nothing; kwargs...)
CommonSolve.solve(prob, SCCAlg(nothing, nothing); kwargs...)
end

function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg; kwargs...)
numscc = length(prob.probs)
sols = [SciMLBase.build_solution(
prob, nothing, prob.u0, convert(eltype(prob.u0), NaN) * prob.u0)
for prob in prob.probs]
u = reduce(vcat, [prob.u0 for prob in prob.probs])
resid = copy(u)

lasti = 1
for i in 1:numscc
prob.explictfuns![i](
SymbolicIndexingInterface.parameter_values(prob.probs[i]), sols)
sol = SciMLBase.solve(prob.probs[i], alg; kwargs...)
_sol = SciMLBase.build_solution(
prob.probs[i], nothing, sol.u, sol.resid, retcode = sol.retcode)
sols[i] = _sol
lasti = i
if !SciMLBase.successful_retcode(_sol)
break
end
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::AbstractNonlinearAlgorithm; kwargs...)
CommonSolve.solve(prob, SCCAlg(alg, nothing); kwargs...)
end

probvec(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}) = prob.u0
probvec(prob::LinearProblem) = prob.b

iteratively_build_sols(alg, sols; kwargs...) = sols

function iteratively_build_sols(alg, sols, (prob, explicitfun), args...; kwargs...)
explicitfun(
SymbolicIndexingInterface.parameter_values(prob), sols)

_sol = if prob isa SciMLBase.LinearProblem
sol = SciMLBase.solve(prob, alg.linalg; kwargs...)
SciMLBase.build_linear_solution(
alg.linalg, sol.u, nothing, nothing, retcode = sol.retcode)
else
sol = SciMLBase.solve(prob, alg.nlalg; kwargs...)
SciMLBase.build_solution(
prob, nothing, sol.u, sol.resid, retcode = sol.retcode)
end

iteratively_build_sols(alg, (sols..., _sol), args...)
end

function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...)
numscc = length(prob.probs)
sols = iteratively_build_sols(
alg, (), zip(prob.probs, prob.explicitfuns!)...; kwargs...)

# TODO: fix allocations with a lazy concatenation
u .= reduce(vcat, sols)
resid .= reduce(vcat, getproperty.(sols, :resid))
u = reduce(vcat, sols)
resid = reduce(vcat, getproperty.(sols, :resid))

retcode = sols[lasti].retcode
retcode = if !all(SciMLBase.successful_retcode, sols)
idx = findfirst(!SciMLBase.successful_retcode, sols)
sols[idx].retcode
else
SciMLBase.ReturnCode.Success
end

SciMLBase.build_solution(prob, alg, u, resid; retcode, original = sols)
end
Expand Down
87 changes: 47 additions & 40 deletions lib/SCCNonlinearSolve/test/core_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,64 +7,71 @@ end
@testitem "Manual SCC" setup=[CoreRootfindTesting] tags=[:core] begin
using NonlinearSolveFirstOrder
function f(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
du[3] = 2u[4] + u[3] + 1.0
du[4] = u[5]^2 + u[4]
du[5] = u[3]^2 + u[5]
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]
du[1]=cos(u[2])-u[1]
du[2]=sin(u[1]+u[2])+u[2]
du[3]=2u[4]+u[3]+1.0
du[4]=u[5]^2+u[4]
du[5]=u[3]^2+u[5]
du[6]=u[1]+u[2]+u[3]+u[4]+u[5]+2.0u[6]+2.5u[7]+1.5u[8]
du[7]=u[1]+u[2]+u[3]+2.0u[4]+u[5]+4.0u[6]-1.5u[7]+1.5u[8]
du[8]=u[1]+2.0u[2]+3.0u[3]+5.0u[4]+6.0u[5]+u[6]-u[7]-u[8]
Comment on lines +10 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
du[1]=cos(u[2])-u[1]
du[2]=sin(u[1]+u[2])+u[2]
du[3]=2u[4]+u[3]+1.0
du[4]=u[5]^2+u[4]
du[5]=u[3]^2+u[5]
du[6]=u[1]+u[2]+u[3]+u[4]+u[5]+2.0u[6]+2.5u[7]+1.5u[8]
du[7]=u[1]+u[2]+u[3]+2.0u[4]+u[5]+4.0u[6]-1.5u[7]+1.5u[8]
du[8]=u[1]+2.0u[2]+3.0u[3]+5.0u[4]+6.0u[5]+u[6]-u[7]-u[8]
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
du[3] = 2u[4] + u[3] + 1.0
du[4] = u[5]^2 + u[4]
du[5] = u[3]^2 + u[5]
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]

end
prob = NonlinearProblem(f, zeros(8))
sol = solve(prob, NewtonRaphson())
prob=NonlinearProblem(f, zeros(8))
sol=solve(prob, NewtonRaphson())
Comment on lines +19 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
prob=NonlinearProblem(f, zeros(8))
sol=solve(prob, NewtonRaphson())
prob = NonlinearProblem(f, zeros(8))
sol = solve(prob, NewtonRaphson())


u0 = zeros(2)
p = zeros(3)
u0=zeros(2)
p=zeros(3)
Comment on lines +22 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
u0=zeros(2)
p=zeros(3)
u0 = zeros(2)
p = zeros(3)


function f1(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
du[1]=cos(u[2])-u[1]
du[2]=sin(u[1]+u[2])+u[2]
Comment on lines +26 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
du[1]=cos(u[2])-u[1]
du[2]=sin(u[1]+u[2])+u[2]
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]

end
explicitfun1(p, sols) = nothing
prob1 = NonlinearProblem(
explicitfun1(p, sols)=nothing
prob1=NonlinearProblem(
Comment on lines +29 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
explicitfun1(p, sols)=nothing
prob1=NonlinearProblem(
explicitfun1(p, sols) = nothing
prob1 = NonlinearProblem(

NonlinearFunction{true, SciMLBase.NoSpecialize}(f1), zeros(2), p)
sol1 = solve(prob1, NewtonRaphson())
sol1=solve(prob1, NewtonRaphson())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol1=solve(prob1, NewtonRaphson())
sol1 = solve(prob1, NewtonRaphson())


function f2(du, u, p)
du[1] = 2u[2] + u[1] + 1.0
du[2] = u[3]^2 + u[2]
du[3] = u[1]^2 + u[3]
du[1]=2u[2]+u[1]+1.0
du[2]=u[3]^2+u[2]
du[3]=u[1]^2+u[3]
Comment on lines +35 to +37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
du[1]=2u[2]+u[1]+1.0
du[2]=u[3]^2+u[2]
du[3]=u[1]^2+u[3]
du[1] = 2u[2] + u[1] + 1.0
du[2] = u[3]^2 + u[2]
du[3] = u[1]^2 + u[3]

end
explicitfun2(p, sols) = nothing
prob2 = NonlinearProblem(
explicitfun2(p, sols)=nothing
prob2=NonlinearProblem(
Comment on lines +39 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
explicitfun2(p, sols)=nothing
prob2=NonlinearProblem(
explicitfun2(p, sols) = nothing
prob2 = NonlinearProblem(

NonlinearFunction{true, SciMLBase.NoSpecialize}(f2), zeros(3), p)
sol2 = solve(prob2, NewtonRaphson())
sol2=solve(prob2, NewtonRaphson())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol2=solve(prob2, NewtonRaphson())
sol2 = solve(prob2, NewtonRaphson())


function f3(du, u, p)
du[1] = p[1] + 2.0u[1] + 2.5u[2] + 1.5u[3]
du[2] = p[2] + 4.0u[1] - 1.5u[2] + 1.5u[3]
du[3] = p[3] + +u[1] - u[2] - u[3]
end
prob3 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3), zeros(3), p)
# Convert f3 to a LinearProblem since it's linear in u
# du = Au + b where A is the coefficient matrix and b is from parameters
A3=[2.0 2.5 1.5; 4.0 -1.5 1.5; 1.0 -1.0 -1.0]
b3=p # b will be updated by explicitfun3
prob3=LinearProblem(A3, b3, zeros(3))
Comment on lines +46 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A3=[2.0 2.5 1.5; 4.0 -1.5 1.5; 1.0 -1.0 -1.0]
b3=p # b will be updated by explicitfun3
prob3=LinearProblem(A3, b3, zeros(3))
A3 = [2.0 2.5 1.5; 4.0 -1.5 1.5; 1.0 -1.0 -1.0]
b3 = p # b will be updated by explicitfun3
prob3 = LinearProblem(A3, b3, zeros(3))

function explicitfun3(p, sols)
p[1] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
p[2] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
p[3] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
6.0sols[2][3]
p[1]=-(sols[1][1]+sols[1][2]+sols[2][1]+sols[2][2]+sols[2][3])
p[2]=-(sols[1][1]+sols[1][2]+sols[2][1]+2.0sols[2][2]+sols[2][3])
p[3]=-(sols[1][1]+2.0sols[1][2]+3.0sols[2][1]+5.0sols[2][2]+
6.0sols[2][3])
Comment on lines +50 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
p[1]=-(sols[1][1]+sols[1][2]+sols[2][1]+sols[2][2]+sols[2][3])
p[2]=-(sols[1][1]+sols[1][2]+sols[2][1]+2.0sols[2][2]+sols[2][3])
p[3]=-(sols[1][1]+2.0sols[1][2]+3.0sols[2][1]+5.0sols[2][2]+
6.0sols[2][3])
p[1] = -(sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3])
p[2] = -(sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3])
p[3] = -(sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
6.0sols[2][3])

end
explicitfun3(p, [sol1, sol2])
sol3 = solve(prob3, NewtonRaphson())
manualscc = [sol1; sol2; sol3]
sol3=solve(prob3) # LinearProblem uses default linear solver
manualscc=reduce(vcat, (sol1, sol2, sol3))
Comment on lines +56 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sol3=solve(prob3) # LinearProblem uses default linear solver
manualscc=reduce(vcat, (sol1, sol2, sol3))
sol3 = solve(prob3) # LinearProblem uses default linear solver
manualscc = reduce(vcat, (sol1, sol2, sol3))


sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
sccprob=SciMLBase.SCCNonlinearProblem((prob1, prob2, prob3),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
sccprob=SciMLBase.SCCNonlinearProblem((prob1, prob2, prob3),
sccprob = SciMLBase.SCCNonlinearProblem((prob1, prob2, prob3),

SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]))
scc_sol = solve(sccprob, NewtonRaphson())

# Test with SCCAlg that handles both nonlinear and linear problems
using SCCNonlinearSolve
scc_alg=SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson(), linalg = nothing)
scc_sol=solve(sccprob, scc_alg)
Comment on lines +64 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
scc_alg=SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson(), linalg = nothing)
scc_sol=solve(sccprob, scc_alg)
scc_alg = SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson(), linalg = nothing)
scc_sol = solve(sccprob, scc_alg)

@test sol ≈ manualscc ≈ scc_sol

# Backwards compat of alg choice
scc_sol=solve(sccprob, NewtonRaphson())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
scc_sol=solve(sccprob, NewtonRaphson())
scc_sol = solve(sccprob, NewtonRaphson())

@test sol ≈ manualscc ≈ scc_sol

import NonlinearSolve # Required for Default

scc_sol = solve(sccprob)
@test sol ≈ manualscc ≈ scc_sol
# Test default interface
scc_sol_default=solve(sccprob)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
scc_sol_default=solve(sccprob)
scc_sol_default = solve(sccprob)

@test sol ≈ manualscc ≈ scc_sol_default
end
Loading