Skip to content

Commit

Permalink
Dispatch on solve_model (#108)
Browse files Browse the repository at this point in the history
* Dispatch on `solve_model`
  • Loading branch information
tmigot authored Jan 26, 2024
1 parent 840398f commit ccbff0f
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 48 deletions.
10 changes: 5 additions & 5 deletions docs/src/doityourself.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ AdaptiveRegularization.ALL_solvers

To make your own variant we need to implement:
- A new data structure `<: PData{T}` for some real number type `T`.
- A `preprocess(PData::TPData, H, g, gNorm2, α)` function called before each trust-region iteration.
- A `solve_model(PData::PDataST, H, g, gNorm2, n1, n2, δ::T)` function used to solve the algorithm subproblem.
- A `preprocess!(PData::TPData, H, g, gNorm2, n1, n2, α)` function called before each trust-region iteration.
- A `solve_model!(PData::TPData, H, g, gNorm2, n1, n2, δ::T)` function used to solve the algorithm subproblem.

In the rest of this tutorial, we implement a Steihaug-Toint trust-region method using `cg_lanczos` from [`Krylov.jl`](https://github.com/JuliaSmoothOptimizers/Krylov.jl) to solve the linear subproblem with trust-region constraint.

Expand Down Expand Up @@ -57,13 +57,13 @@ end
```
For our Steihaug-Toint implementation, we do not run any preprocess operation, so we use the default one.
```@example 1
function AdaptiveRegularization.preprocess(PData::AdaptiveRegularization.TPData, H, g, gNorm2, n1, n2, α)
function AdaptiveRegularization.preprocess!(PData::AdaptiveRegularization.TPData, H, g, gNorm2, n1, n2, α)
return PData
end
```
We now solve the subproblem.
```@example 1
function solve_modelST_TR(PData::PDataST, H, g, gNorm2, calls, max_calls, δ::T) where {T}
function AdaptiveRegularization.solve_model!(PData::PDataST, H, g, gNorm2, calls, max_calls, δ::T) where {T}
ζ, ξ, maxtol, mintol = PData.ζ, PData.ξ, PData.maxtol, PData.mintol
n = length(g)
# precision = max(1e-12, min(0.5, (gNorm2^ζ)))
Expand Down Expand Up @@ -92,7 +92,7 @@ end

We can now proceed with the main solver call, specifying the `pdata_type` to use; the matching `solve_model!` method is then selected by dispatch on that type. Since `Krylov.cg_lanczos` only uses matrix-vector products, it is sufficient to evaluate the Hessian matrix as an operator, so we provide `hess_type = HessOp`.
```@example 1
ST_TROp(nlp; kwargs...) = TRARC(nlp, pdata_type = PDataST, solve_model = solve_modelST_TR, hess_type = HessOp; kwargs...)
ST_TROp(nlp; kwargs...) = TRARC(nlp, pdata_type = PDataST, hess_type = HessOp; kwargs...)
```
Finally, we can apply our new method to any [`NLPModels`](https://github.com/JuliaSmoothOptimizers/NLPModels.jl).
```@example 1
Expand Down
5 changes: 2 additions & 3 deletions src/AdaptiveRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ The keyword arguments include
- `TR::TrustRegion`: structure with trust-region/ARC parameters, see [`TrustRegion`](@ref). Default: `TrustRegion(T(10.0))`.
- `hess_type::Type{Hess}`: Structure used to handle the hessian. The possible values are: `HessDense`, `HessSparse`, `HessSparseCOO`, `HessOp`. Default: `HessOp`.
- `pdata_type::Type{ParamData}`: Structure used for the preprocessing step; its type also selects the `solve_model!` method used to solve the subproblem. Default: `PDataKARC`.
- `solve_model::Function` Function used to solve the subproblem. Default: `solve_modelKARC`.
- `robust::Bool`: `true` implements a robust evaluation of the model. Default: `true`.
- `verbose::Bool`: `true` prints iteration information. Default: `false`.
Additional `kwargs` are used for stopping criterion, see `Stopping.jl`.
Expand Down Expand Up @@ -127,7 +126,7 @@ function TRARC(nlp_stop::NLPStopping; kwargs...)
end

for fun in union(keys(solvers_const), keys(solvers_nls_const))
ht, pt, sm, ka = merge(solvers_const, solvers_nls_const)[fun]
ht, pt, ka = merge(solvers_const, solvers_nls_const)[fun]
@eval begin
function $fun(nlpstop::NLPStopping; kwargs...)
kw_list = Dict{Symbol, Any}()
Expand All @@ -137,7 +136,7 @@ for fun in union(keys(solvers_const), keys(solvers_nls_const))
end
end
merge!(kw_list, Dict(kwargs))
TRARC(nlpstop; hess_type = $ht, pdata_type = $pt, solve_model = $sm, kw_list...)
TRARC(nlpstop; hess_type = $ht, pdata_type = $pt, kw_list...)
end
end
@eval begin
Expand Down
2 changes: 1 addition & 1 deletion src/SolveModel/SolveModelKARC.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function solve_modelKARC(X::PDataKARC, H, g, gNorm2, calls, max_calls, α::T) where {T}
function solve_model!(X::PDataKARC{S, T, Fatol, Frtol}, H, g, gNorm2, calls, max_calls, α::T) where {S, T, Fatol, Frtol}
# target value should be close to satisfy αλ=||d||
start = findfirst(X.positives)
if isnothing(start)
Expand Down
2 changes: 1 addition & 1 deletion src/SolveModel/SolveModelNLSST_TR.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function solve_modelNLSST_TR(PData::PDataNLSST, Jx, Fx, norm_∇f, calls, max_calls, δ::T) where {T}
function solve_model!(PData::PDataNLSST{S, T, Fatol, Frtol}, Jx, Fx, norm_∇f, calls, max_calls, δ::T) where {S, T, Fatol, Frtol}
# special case: Steihaug-Toint
# ϵ = sqrt(eps(T)) # * 100.0 # old
# cgtol = max(ϵ, min(cgtol, 9 * cgtol / 10, 0.01 * norm(g)^(1.0 + PData.ζ))) # old
Expand Down
2 changes: 1 addition & 1 deletion src/SolveModel/SolveModelST_TR.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function solve_modelST_TR(PData::PDataST, H, g, gNorm2, calls, max_calls, δ::T) where {T}
function solve_model!(PData::PDataST{S, T, Fatol, Frtol}, H, g, gNorm2, calls, max_calls, δ::T) where {S, T, Fatol, Frtol}
# special case: Steihaug-Toint
# ϵ = sqrt(eps(T)) # * 100.0 # old
# cgtol = max(ϵ, min(cgtol, 9 * cgtol / 10, 0.01 * norm(g)^(1.0 + PData.ζ))) # old
Expand Down
2 changes: 1 addition & 1 deletion src/SolveModel/SolveModelTRK.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function solve_modelTRK(X::PDataTRK, H, g, gNorm2, calls, max_calls, α::T) where {T}
function solve_model!(X::PDataTRK{S, T, Fatol, Frtol}, H, g, gNorm2, calls, max_calls, α::T) where {S, T, Fatol, Frtol}
# target value should be close to satisfy α=||d||
start = findfirst(X.positives)
if isnothing(start)
Expand Down
15 changes: 7 additions & 8 deletions src/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ function compute_direction(
∇f,
norm_∇f,
α,
solve_model,
) where {T}
max_hprod = stp.meta.max_cntrs[:neval_hprod]
Hx = workspace.Hstruct.H
return solve_model(PData, Hx, ∇f, norm_∇f, neval_hprod(stp.pb), max_hprod, α)
solve_model!(PData, Hx, ∇f, norm_∇f, neval_hprod(stp.pb), max_hprod, α)
return PData.d, PData.λ
end

function compute_direction(
Expand All @@ -33,12 +33,12 @@ function compute_direction(
∇f,
norm_∇f,
α,
solve_model,
)
max_prod = stp.meta.max_cntrs[:neval_jprod_residual]
Jx = jac_op_residual(stp.pb, workspace.xt)
Fx = workspace.Fx
return solve_model(PData, Jx, Fx, norm_∇f, neval_jprod_residual(stp.pb), max_prod, α)
solve_model!(PData, Jx, Fx, norm_∇f, neval_jprod_residual(stp.pb), max_prod, α)
return PData.d, PData.λ
end

function compute_direction(
Expand All @@ -48,12 +48,12 @@ function compute_direction(
∇f,
norm_∇f,
α,
solve_model,
) where {T, S, Hess <: HessGaussNewtonOp}
max_prod = stp.meta.max_cntrs[:neval_jprod_residual]
Jx = jac_op_residual!(stp.pb, workspace.xt, workspace.Hstruct.Jv, workspace.Hstruct.Jtv)
Fx = workspace.Fx
return solve_model(PData, Jx, Fx, norm_∇f, neval_jprod_residual(stp.pb), max_prod, α)
solve_model!(PData, Jx, Fx, norm_∇f, neval_jprod_residual(stp.pb), max_prod, α)
return PData.d, PData.λ
end

function hessian!(workspace::TRARCWorkspace, nlp, x)
Expand Down Expand Up @@ -100,7 +100,6 @@ function SolverCore.solve!(
solver::TRARCSolver{T, S},
nlp_stop::NLPStopping{Pb, M, SRC, NLPAtX{Score, T, S}, MStp, LoS},
stats::GenericExecutionStats{T, S};
solve_model::Function = solve_modelKARC,
robust::Bool = true,
verbose::Integer = false,
kwargs...,
Expand Down Expand Up @@ -156,7 +155,7 @@ function SolverCore.solve!(

success = false
while !success & (unsuccinarow < max_unsuccinarow)
d, λ = compute_direction(nlp_stop, PData, workspace, ∇f, norm_∇f, α, solve_model)
d, λ = compute_direction(nlp_stop, PData, workspace, ∇f, norm_∇f, α)

slope = ∇f d
Δq = compute_Δq(workspace, Hx, d, ∇f)
Expand Down
29 changes: 14 additions & 15 deletions src/solvers.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
const solvers_const = Dict(
:ARCqKdense =>
(HessDense, PDataKARC, solve_modelKARC, [(:shifts => 10.0 .^ (collect(-20.0:1.0:20.0)))]),
(HessDense, PDataKARC, [(:shifts => 10.0 .^ (collect(-20.0:1.0:20.0)))]),
:ARCqKOp =>
(HessOp, PDataKARC, solve_modelKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
(HessOp, PDataKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
:ARCqKsparse =>
(HessSparse, PDataKARC, solve_modelKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
(HessSparse, PDataKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
:ARCqKCOO =>
(HessSparseCOO, PDataKARC, solve_modelKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
:ST_TRdense => (HessDense, PDataST, solve_modelST_TR, ()),
:ST_TROp => (HessOp, PDataST, solve_modelST_TR, ()),
:ST_TRsparse => (HessSparse, PDataST, solve_modelST_TR, ()),
:TRKdense => (HessDense, PDataTRK, solve_modelTRK, ()),
:TRKOp => (HessOp, PDataTRK, solve_modelTRK, ()),
:TRKsparse => (HessSparse, PDataTRK, solve_modelTRK, ()),
(HessSparseCOO, PDataKARC, [:shifts => 10.0 .^ (collect(-20.0:1.0:20.0))]),
:ST_TRdense => (HessDense, PDataST, ()),
:ST_TROp => (HessOp, PDataST, ()),
:ST_TRsparse => (HessSparse, PDataST, ()),
:TRKdense => (HessDense, PDataTRK, ()),
:TRKOp => (HessOp, PDataTRK, ()),
:TRKsparse => (HessSparse, PDataTRK, ()),
)

const solvers_nls_const = Dict(
:ARCqKOpGN => (
HessGaussNewtonOp,
PDataKARC,
solve_modelKARC,
[:shifts => 10.0 .^ (collect(-10.0:0.5:20.0))],
),
:ST_TROpGN => (HessGaussNewtonOp, PDataST, solve_modelST_TR, ()),
:ST_TROpGN => (HessGaussNewtonOp, PDataST, ()),
:ST_TROpGNLSCgls =>
(HessGaussNewtonOp, PDataNLSST, solve_modelNLSST_TR, [:solver_method => :cgls]),
(HessGaussNewtonOp, PDataNLSST, [:solver_method => :cgls]),
:ST_TROpGNLSLsqr =>
(HessGaussNewtonOp, PDataNLSST, solve_modelNLSST_TR, [:solver_method => :lsqr]),
:ST_TROpLS => (HessOp, PDataNLSST, solve_modelNLSST_TR, ()),
(HessGaussNewtonOp, PDataNLSST, [:solver_method => :lsqr]),
:ST_TROpLS => (HessOp, PDataNLSST, ()),
)
7 changes: 4 additions & 3 deletions src/utils/pdata_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ function preprocess!(PData::TPData{T}, H, g, gNorm2, n1, n2, α) where {T}
end

"""
solve_model(PData::TPData, H, g, gNorm2, n1, n2, α)
solve_model!(PData::TPData, H, g, gNorm2, n1, n2, α)
Function called in the `TRARC` algorithm to solve the subproblem.
# Arguments
- `PData::TPData`: data structure used for preprocessing.
- `H`: current Hessian matrix.
Expand All @@ -40,9 +41,9 @@ Function called in the `TRARC` algorithm to solve the subproblem.
- `n2`: Maximum number of Hessian-vector products accepted.
- `α`: current value of the TR/ARC parameter.
It returns a couple `(PData.d, PData.λ)`. Current implementations include: `solve_modelKARC`, `solve_modelTRK`, `solve_modelST_TR`.
It returns a pair `(PData.d, PData.λ)`.
"""
function solve_model(X::TPData{T}, H, g, gNorm2, n1, n2, α) where {T} end
function solve_model!(X::TPData{T}, H, g, gNorm2, n1, n2, α) where {T} end

"""
PDataKARC(::Type{S}, ::Type{T}, n)
Expand Down
14 changes: 5 additions & 9 deletions test/allocation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@ function alloc_preprocess(XData, H, g, ng, calls, max_calls, α)
return nothing
end

function alloc_solve_model(solve, XData, H, g, ng, calls, max_calls, α)
solve(XData, H, g, ng, calls, max_calls, α)
function alloc_solve_model(XData, H, g, ng, calls, max_calls, α)
solve_model!(XData, H, g, ng, calls, max_calls, α)
return nothing
end

@testset "Allocation test in preprocess and solvemodel" begin
for (Data, solve) in (
(PDataKARC, AdaptiveRegularization.solve_modelKARC),
(PDataTRK, AdaptiveRegularization.solve_modelTRK),
(PDataST, AdaptiveRegularization.solve_modelST_TR),
)
for Data in (PDataKARC, PDataTRK, PDataST)
XData = Data(S, T, n)
@testset "Allocation test in preprocess with $(Data)" begin
alloc_preprocess(XData, H, g, ng, calls, max_calls, α)
Expand All @@ -34,8 +30,8 @@ end
end

@testset "Allocation test in $solve with $(Data)" begin
alloc_solve_model(solve, XData, H, g, ng, calls, max_calls, α)
al = @allocated alloc_solve_model(solve, XData, H, g, ng, calls, max_calls, α)
alloc_solve_model(XData, H, g, ng, calls, max_calls, α)
al = @allocated alloc_solve_model(XData, H, g, ng, calls, max_calls, α)
@test al == 0
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/allocation_test_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ S, T = typeof(x0), eltype(x0)
PData = PDataKARC(S, T, n)

function alloc_AdaptiveRegularization(stp, solver, stats)
solve!(solver, stp, stats, solve_model = AdaptiveRegularization.solve_modelKARC)
solve!(solver, stp, stats)
return nothing
end

Expand Down

0 comments on commit ccbff0f

Please sign in to comment.