
Link/Provide full tested tutorial for Optimization in FAQ #3952

@bgctw


Is your feature request related to a problem? Please describe.

I find it hard to translate the guidance on optimization and automatic differentiation (AD) in the FAQ into code that actually performs a parameter optimization.

Others have also encountered difficulties with Optimization:

Describe the solution you’d like

  • Provide a short tutorial with executable code that actually optimizes a subset of the parameters and initial conditions of an ODEProblem derived from an MTK system.
  • Implement tests for this problem that use different optimizers and AD backends (including at least ForwardDiff and Zygote).

Describe alternatives you’ve considered

Additional context

I propose a modification of the SciMLSensitivity example.

The following code is a starting point for the tutorial. It tries three different approaches to updating the problem and runs into several obstacles. The third alternative works, but only with ForwardDiff.jl, and it is rather complicated.

The example uses the standard, fairly simple Lotka-Volterra problem, but adds some complexity by using a non-scalar parameter, px[1:2] = [α, β].
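For reference, the model defined in the code reads as follows, with px = [α, β] and defaults α = 1.5, β = 1.0, γ = 3.0, δ = 1.0 (δ is never optimized). The first two approaches fit α, β and γ; the third additionally fits the initial value x(0). All of them minimize the squared deviation from the reference solution at the save times (the third approach returns 1e30 instead if the solve fails). After mtkcompile, z becomes an observed variable, so the loss only compares x and y:

$$
\begin{aligned}
\frac{dx}{dt} &= \alpha x - \beta x y, \qquad x(0) = 1.1,\\
\frac{dy}{dt} &= -\gamma y + \delta x y, \qquad y(0) = 1.2,\\
z &= x + y,\\
L(\theta) &= \sum_{i} \sum_{s \in \{x,\, y\}} \bigl( s(t_i;\theta) - s^{\mathrm{true}}(t_i) \bigr)^2, \qquad t_i \in \{0, 0.1, \dots, 10\}.
\end{aligned}
$$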

import Pkg; 
Pkg.activate(;temp=true)
Pkg.add(["OrdinaryDiffEq","Optimization","OptimizationPolyalgorithms","SciMLSensitivity",
  "Zygote","ForwardDiff","ModelingToolkit","SymbolicIndexingInterface",
  "SciMLStructures","SciMLBase","ComponentArrays","Plots"])
import OrdinaryDiffEq as ODE
import Optimization as OPT
import OptimizationPolyalgorithms as OPA
import SciMLSensitivity as SMS
import Zygote
import ForwardDiff
using SciMLBase
using ComponentArrays
#@usingany Plots: Plots
#using Plots: Plots

using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkit: t_nounits as t, D_nounits as D, ModelingToolkit as MTK
using SymbolicIndexingInterface: SymbolicIndexingInterface as SII
using SciMLStructures: SciMLStructures as SS

@variables x(t) y(t) z(t)
@parameters px[1:2]=[1.5, 1.0] γ=3.0 δ=1.0

eqs = [D(x) ~ px[1] * x - px[2] * x * y
       D(y) ~ -γ * y + δ * x * y
       z ~ x + y]

@named sys = System(eqs, t)
simpsys = mtkcompile(sys)
tsteps = 0.0:0.1:10.0
tspan = extrema(tsteps)
prob = ODEProblem(simpsys, [x => 1.1, y => 1.2], tspan)
sol = sol_true = ODE.solve(prob, ODE.Tsit5(), saveat = tsteps)
#Plots.plot(sol)

paropt = [px, γ]
paropt_sym = Symbol.(paropt)
n_paropt = sum(length.(paropt))
stateopt = [x]
stateopt_sym = Symbol.(stateopt)
#p_true = vcat(prob.ps[paropt]..., prob[stateopt]...) 
p_true = vcat(
    ComponentVector(;zip(paropt_sym,prob.ps[paropt])...),
    ComponentVector(;zip(stateopt_sym,prob[stateopt])...))
p0 = p = p_true .+ randn(length(p_true)) .* 0.2
probo = remake(prob) # copy
nest_structure(p, syms) = [p[k] for k in syms]
parstateopt = vcat(paropt, stateopt)
parstateopt_sym = Symbol.(parstateopt)
p0n = nest_structure(p0, parstateopt_sym)
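# Sketch (hypothetical helper `unflatten_params`, Base Julia only, not an MTK API):
# the optimizer works on a flat vector, while the setters below expect one entry
# per symbolic parameter or state, e.g. [[α, β], γ, x0] for parstateopt = [px, γ, x].
# `nest_structure` obtains this via the ComponentVector keys; the same nesting can
# be rebuilt from the lengths of the entries alone, without mutation.
function unflatten_params(v::AbstractVector, lens::AbstractVector{<:Integer})
    ends = cumsum(lens)
    starts = ends .- lens .+ 1
    # entries of length 1 become scalars, longer entries become subvectors
    return [l == 1 ? v[s] : v[s:e] for (s, e, l) in zip(starts, ends, lens)]
end
# e.g. unflatten_params(Vector(p0), length.(parstateopt)) should give the same
# nesting as p0n above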

#--------------- recommended method for optimization according to the MTK FAQ
# obstacle: could not find documentation on how to set initial conditions this way
# obstacle: problems computing gradients with both Zygote and ForwardDiff
setter! = SII.setp(simpsys, paropt)
setter!(probo, nest_structure(p0, paropt_sym)) 
function loss(p)
    local p_struc = nest_structure(p, paropt_sym) #  omit u0
    setter!(probo, p_struc)
    local sol = ODE.solve(probo, ODE.Tsit5(), saveat = tsteps)
    local loss = sum(abs2, sol .- sol_true)
    return loss
end
loss(p0)
#Zygote.gradient(loss, p0)
#ForwardDiff.gradient(loss, p0)


#------------- alternative using indices
# initial values are documented for this approach (cf. MTK.variable_index below)
# obstacle: problems computing gradients with both Zygote and ForwardDiff
probo = remake(prob)
ps_ind = MTK.parameter_index.(Ref(simpsys), paropt)
setindex!.(Ref(probo.ps), nest_structure(p0, paropt_sym), ps_ind)
probo.ps[px] == p0[Symbol("px[1:2]")]
x_ind = MTK.variable_index.(Ref(simpsys), stateopt) # 2
#setindex!(probo.ps, [p0[3]], x_ind)

function loss(p)
    setindex!.(Ref(probo.ps), nest_structure(p, paropt_sym), ps_ind)
    # for simplicity, omit u0 for now
    local xv = probo.u0
    # local xvn = [begin
    #     ii = findfirst(==(i), x_ind)
    #     isnothing(ii) ? xv[i] : p[length(ps_ind) + ii]
    # end for i in  axes(xv,1)]
    #local xvn = vcat(xv[1], p[3])
    #probox = remake(probo, u0 = xvn) 
    local sol = ODE.solve(probo, ODE.Tsit5(), saveat = tsteps)
    local loss = sum(abs2, sol .- sol_true)
    return loss
end
loss(p0)
Zygote.gradient(loss, p0) # nothing?
#ForwardDiff.gradient(loss, p0)


#--------------- alternative: SciMLStructures
# works with ForwardDiff
# obstacle: how to infer the positions of the optimized parameters in the canonicalized buffer?
# obstacle: Zygote.gradient errors or returns zero for the initial conditions?
probo = remake(prob)
function loss(p)
    local sol = compute_sol(p, probo)
    local loss = SciMLBase.successful_retcode(sol) ? sum(abs2, sol .- sol_true) : 1e30
    return loss
end
function compute_sol(p, probo) # variant without Zygote compile error, but not general
    local pv = SII.parameter_values(probo)
    local buf, _, _ = SS.canonicalize(SS.Tunable(), pv)
    local bufx, _, _ = SS.canonicalize(SS.Initials(), pv)
    local bufn = vcat(p[1:2], buf[3], p[3])       # TODO: positions hard-coded; describe the general case
    local pvn = SS.replace(SS.Tunable(), pv, bufn)
    local bufxn = vcat(bufx[1], p[4], bufx[3:end]) # TODO: positions hard-coded; describe the general case
    local pvn2 = SS.replace(SS.Initials(), pvn, bufxn)
    #local probon = remake(probo; u0 = xvn, p = pvn) # u0 set to pvn.initials instead
    local probon = remake(probo; p = pvn2) 
    local probon2 = probon
    #local xvn = vcat(xv[1], p[3])
    #local probon2 = remake(probon; u0 = xvn) # mutating?
    #@show probon2.u0, probon2.p
    local sol = ODE.solve(probon2, ODE.Tsit5(), saveat = tsteps)
    return sol
end
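# Sketch for the hard-coded TODO lines above (hypothetical helper `splice_buffer`,
# Base Julia only): given the positions `idxs` of the optimized entries in a
# canonicalized buffer, build a new buffer in which those positions hold the
# optimized values `vals` and all other entries are copied from `buf`. It allocates
# a new vector instead of mutating `buf`, and promotes the eltype so that, e.g.,
# dual numbers fit. How to obtain `idxs` from MTK is the open question above and
# is not shown here.
function splice_buffer(buf::AbstractVector, idxs::AbstractVector{<:Integer}, vals::AbstractVector)
    length(idxs) == length(vals) ||
        throw(ArgumentError("idxs and vals must have the same length"))
    T = promote_type(eltype(buf), eltype(vals))
    return [begin
        k = findfirst(==(i), idxs)       # is position i one of the optimized ones?
        convert(T, isnothing(k) ? buf[i] : vals[k])
    end for i in eachindex(buf)]
end
# e.g. for the 4-element Tunable buffer, vcat(p[1:2], buf[3], p[3]) above
# corresponds to splice_buffer(buf, [1, 2, 4], p[1:3])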
psetter! = SII.setp(probo, paropt)
ssetter! = SII.setu(probo, stateopt)
function compute_sol(p, probo) # more general variant; overwrites the method defined above
    local pv = SII.parameter_values(probo)
    local buf, _, _ = SS.canonicalize(SS.Tunable(), pv)
    local bufx, _, _ = SS.canonicalize(SS.Initials(), pv)
    # for ForwardDiff, the eltype of the entire Tunable and Initials portions must be converted to eltype(p)
    ET = eltype(p)
    local pvt = pv
    pvt = eltype(buf) == ET ? pvt : SS.replace(SS.Tunable(), pvt, ET.(buf)) 
    pvt = eltype(bufx) == ET ? pvt : SS.replace(SS.Initials(), pvt, ET.(bufx)) 
    #local psetter! = SII.setp(probo, paropt)
    local p_struc = nest_structure(p, paropt_sym) 
    psetter!(pvt, p_struc)
    #local ssetter! = SII.setu(probo, stateopt)
    local s_struc = nest_structure(p, stateopt_sym) 
    #ssetter!(pvt, s_struc)   # state_values(MTKParameters) not implemented
    local probon2 = remake(probo, p = pvt)
    ssetter!(probon2, s_struc)
    #local sol = ODE.solve(probon2, ODE.Tsit5(), saveat = tsteps)
    # need another remake to update probon2.p.initial to probon2.u0
    local probon3 = remake(probon2, u0=probon2.u0)
    #@show probon3.u0, probon3.p
    local sol = ODE.solve(probon3, ODE.Tsit5(), saveat = tsteps)
    return sol
end
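# Sketch of why compute_sol converts the eltype of the whole Tunable and Initials
# portions (hypothetical type `MiniDual` standing in for ForwardDiff.Dual; Base
# Julia only): a Float64 buffer cannot store a dual number, so writing optimized
# values into the existing buffer fails with a MethodError, whereas writing into a
# copy whose eltype was converted first works, analogous to ET.(buf) above.
struct MiniDual <: Real
    val::Float64   # primal value
    der::Float64   # derivative carried along by forward-mode AD
end
to_minidual(v::Real) = MiniDual(Float64(v), 0.0)   # a constant has zero derivative

# write v into buf[i]; return false if the eltype of buf cannot hold v
function try_store!(buf::AbstractVector, i::Integer, v)
    try
        buf[i] = v
        return true
    catch err
        err isa MethodError || rethrow()
        return false
    end
end

demo_buf = [1.5, 1.0, 3.0]            # like a Float64 Tunable portion
demo_val = MiniDual(1.6, 1.0)         # an optimized value as seen during AD
stored_plain = try_store!(demo_buf, 1, demo_val)        # false: no Float64(::MiniDual)
demo_bufd = to_minidual.(demo_buf)                      # convert the whole portion first
stored_promoted = try_store!(demo_bufd, 1, demo_val)    # true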
#include("tmp/test.jl")
loss(p0)
loss(p)
loss(p_true)
#Zygote.gradient(loss, p0)
ForwardDiff.gradient(loss, p0)
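# Sketch (hypothetical helper `fd_gradient`, Base Julia only): a central
# finite-difference gradient as an independent reference, e.g. for the tests with
# different AD backends requested above. Because the loss involves an adaptive ODE
# solve, agreement with AD can only be expected up to a loose tolerance; the step
# size h is an assumption and may need tuning.
function fd_gradient(f, u::AbstractVector{<:AbstractFloat}; h = 1e-6)
    return map(eachindex(u)) do i
        e = zero(u)            # same container type as u (e.g. a ComponentVector)
        e[i] = h
        (f(u .+ e) - f(u .- e)) / (2h)
    end
end
# e.g. isapprox(ForwardDiff.gradient(loss, p0), fd_gradient(loss, p0); rtol = 1e-3)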

callback = function (state, l)
    display(l)
    # p = state.u
    # sol = compute_sol(p, probo)
    # plt = Plots.plot(sol, ylim = (0, 7))
    # display(plt)
    # Returning false tells Optimization.solve to continue;
    # returning true halts the optimization.
    return false
end
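# Sketch (hypothetical factory `make_loss_callback`): a callback with the same
# (state, l) signature that records the loss history and applies the stop rule
# of Optimization.solve, i.e. returns true to halt once the loss falls below
# `tol` (the default tol = 0.0 never stops early for a nonnegative loss).
function make_loss_callback(history::Vector{Float64}; tol::Real = 0.0)
    return function (state, l)
        push!(history, l)   # record the loss of every iteration
        return l < tol      # true requests Optimization.solve to stop
    end
end
# e.g. loss_history = Float64[]; callback = make_loss_callback(loss_history; tol = 1e-8)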

adtype = OPT.AutoForwardDiff()
#adtype = OPT.AutoZygote()
optf = OPT.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = OPT.OptimizationProblem(optf, p0)

opt_alg = OPA.PolyOpt()
#opt_alg = OPA.LBFGS()
result_ode = OPT.solve(optprob, opt_alg,
    callback = callback,
    maxiters = 20,
    )

result_ode.stats
p = result_ode.u
hcat(p0, p, p_true)
loss(result_ode.u)
