Skip to content

Commit 4726654

Browse files
SO always 2 args and some mtk wrapper fixes
1 parent 2bfebf9 commit 4726654

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

ext/OptimizationDIExt.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using ADTypes
1010

1111
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0)
1212
_f = (θ, args...) -> first(f.f(θ, p, args...))
13-
soadtype = DifferentiationInterface.SecondOrder(adtype)
13+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
1414

1515
if f.grad === nothing
1616
extras_grad = prepare_gradient(_f, adtype, x)
@@ -57,6 +57,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
5757
extras_jac = prepare_jacobian(cons_oop, adtype, x)
5858
cons_j = function (J, θ)
5959
jacobian!(cons_oop, J, adtype, θ, extras_jac)
60+
if size(J, 1) == 1
61+
J = vec(J)
62+
end
6063
end
6164
else
6265
cons_j = (J, θ) -> f.cons_j(J, θ, p)
@@ -97,7 +100,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
97100
x = cache.u0
98101
p = cache.p
99102
_f = (θ, args...) -> first(f.f(θ, p, args...))
100-
soadtype = DifferentiationInterface.SecondOrder(adtype)
103+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
101104

102105
if f.grad === nothing
103106
extras_grad = prepare_gradient(_f, adtype, x)
@@ -144,6 +147,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, ca
144147
extras_jac = prepare_jacobian(cons_oop, adtype, x)
145148
cons_j = function (J, θ)
146149
jacobian!(cons_oop, J, adtype, θ, extras_jac)
150+
if size(J, 1) == 1
151+
J = vec(J)
152+
end
147153
end
148154
else
149155
cons_j = (J, θ) -> f.cons_j(J, θ, p)
@@ -183,7 +189,7 @@ end
183189

184190
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::ADTypes.AbstractADType, p = SciMLBase.NullParameters(), num_cons = 0)
185191
_f = (θ, args...) -> first(f.f(θ, p, args...))
186-
soadtype = DifferentiationInterface.SecondOrder(adtype)
192+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
187193

188194
if f.grad === nothing
189195
extras_grad = prepare_gradient(_f, adtype, x)
@@ -229,7 +235,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
229235
if cons !== nothing && f.cons_j === nothing
230236
extras_jac = prepare_jacobian(cons_oop, adtype, x)
231237
cons_j = function (θ)
232-
jacobian(cons_oop, adtype, θ, extras_jac)
238+
J = jacobian(cons_oop, adtype, θ, extras_jac)
239+
if size(J, 1) == 1
240+
J = vec(J)
241+
end
242+
return J
233243
end
234244
else
235245
cons_j = (θ) -> f.cons_j(θ, p)
@@ -241,10 +251,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
241251
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
242252
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
243253

244-
function cons_h(H, θ)
245-
for i in 1:num_cons
254+
function cons_h(θ)
255+
H = map(1:num_cons) do i
246256
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
247257
end
258+
return H
248259
end
249260
else
250261
cons_h = (res, θ) -> f.cons_h(res, θ, p)
@@ -270,7 +281,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
270281
x = cache.u0
271282
p = cache.p
272283
_f = (θ, args...) -> first(f.f(θ, p, args...))
273-
soadtype = DifferentiationInterface.SecondOrder(adtype)
284+
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
274285

275286
if f.grad === nothing
276287
extras_grad = prepare_gradient(_f, adtype, x)
@@ -316,7 +327,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
316327
if cons !== nothing && f.cons_j === nothing
317328
extras_jac = prepare_jacobian(cons_oop, adtype, x)
318329
cons_j = function (θ)
319-
jacobian(cons_oop, adtype, θ, extras_jac)
330+
J = jacobian(cons_oop, adtype, θ, extras_jac)
331+
if size(J, 1) == 1
332+
J = vec(J)
333+
end
334+
return J
320335
end
321336
else
322337
cons_j = (θ) -> f.cons_j(θ, p)
@@ -329,9 +344,10 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, c
329344
extras_cons_hess = prepare_hessian.(fncs, Ref(soadtype), Ref(x))
330345

331346
function cons_h(θ)
332-
for i in 1:num_cons
347+
H = map(1:num_cons) do i
333348
hessian(fncs[i], adtype, θ, extras_cons_hess[i])
334349
end
350+
return H
335351
end
336352
else
337353
cons_h = (θ) -> f.cons_h(θ, p)

ext/OptimizationMTKExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function OptimizationBase.ADTypes.AutoModelingToolkit(sparse = false, cons_spars
1515
end
1616

1717
function OptimizationBase.instantiate_function(
18-
f, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p,
18+
f::OptimizationFunction{true}, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p,
1919
num_cons = 0) where {S, C}
2020
p = isnothing(p) ? SciMLBase.NullParameters() : p
2121

@@ -60,7 +60,7 @@ function OptimizationBase.instantiate_function(
6060
observed = f.observed)
6161
end
6262

63-
function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
63+
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
6464
adtype::AutoSparse{<:AutoSymbolics, S, C}, num_cons = 0) where {S, C}
6565
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p
6666

@@ -106,7 +106,7 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
106106
observed = f.observed)
107107
end
108108

109-
function OptimizationBase.instantiate_function(f, x, adtype::AutoSymbolics, p,
109+
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
110110
num_cons = 0)
111111
p = isnothing(p) ? SciMLBase.NullParameters() : p
112112

@@ -151,7 +151,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoSymbolics, p,
151151
observed = f.observed)
152152
end
153153

154-
function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
154+
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
155155
adtype::AutoSymbolics, num_cons = 0)
156156
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p
157157

0 commit comments

Comments
 (0)