Skip to content

Commit 935c749

Browse files
committed
[WIP] Do not (mis)use objective as state
1 parent d4bf817 commit 935c749

File tree

18 files changed

+293
-257
lines changed

18 files changed

+293
-257
lines changed

src/Manifolds.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,19 @@ end
2828
# TODO: is it safe here to call retract! and change x?
2929
function NLSolversBase.value!(obj::ManifoldObjective, x)
3030
xin = retract(obj.manifold, x)
31-
value!(obj.inner_obj, xin)
32-
end
33-
function NLSolversBase.value(obj::ManifoldObjective)
34-
value(obj.inner_obj)
35-
end
36-
function NLSolversBase.gradient(obj::ManifoldObjective)
37-
gradient(obj.inner_obj)
38-
end
39-
function NLSolversBase.gradient(obj::ManifoldObjective, i::Int)
40-
gradient(obj.inner_obj, i)
31+
return value!(obj.inner_obj, xin)
4132
end
4233
function NLSolversBase.gradient!(obj::ManifoldObjective, x)
4334
xin = retract(obj.manifold, x)
44-
gradient!(obj.inner_obj, xin)
45-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
46-
return gradient(obj.inner_obj)
35+
g_xin = gradient!(obj.inner_obj, xin)
36+
project_tangent!(obj.manifold, g_xin, xin)
37+
return g_xin
4738
end
4839
function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
4940
xin = retract(obj.manifold, x)
50-
value_gradient!(obj.inner_obj, xin)
51-
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
52-
return value(obj.inner_obj)
41+
f_xin, g_xin = value_gradient!(obj.inner_obj, xin)
42+
project_tangent!(obj.manifold, g_xin, xin)
43+
return f_xin, g_xin
5344
end
5445

5546
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""

src/multivariate/optimize/optimize.jl

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,38 @@
11
update_g!(d, state, method) = nothing
22
function update_g!(d, state, method::FirstOrderOptimizer)
33
# Update the function value and gradient
4-
value_gradient!(d, state.x)
5-
project_tangent!(method.manifold, gradient(d), state.x)
4+
f_x, g_x = value_gradient!(d, state.x)
5+
project_tangent!(method.manifold, g_x, state.x)
6+
state.f_x = f_x
7+
copyto!(state.g_x, g_x)
8+
return nothing
69
end
710
function update_g!(d, state, method::Newton)
811
# Update the function value and gradient
9-
value_gradient!(d, state.x)
12+
f_x, g_x = value_gradient!(d, state.x)
13+
state.f_x = f_x
14+
copyto!(state.g_x, g_x)
15+
return nothing
1016
end
17+
1118
update_fg!(d, state, method) = nothing
12-
update_fg!(d, state, method::ZerothOrderOptimizer) = value!(d, state.x)
19+
function update_fg!(d, state, method::ZerothOrderOptimizer)
20+
f_x = value!(d, state.x)
21+
state.f_x = f_x
22+
return nothing
23+
end
1324
function update_fg!(d, state, method::FirstOrderOptimizer)
14-
value_gradient!(d, state.x)
15-
project_tangent!(method.manifold, gradient(d), state.x)
25+
f_x, g_x = value_gradient!(d, state.x)
26+
project_tangent!(method.manifold, g_x, state.x)
27+
state.f_x = f_x
28+
copyto!(state.g_x, g_x)
29+
return nothing
1630
end
1731
function update_fg!(d, state, method::Newton)
18-
value_gradient!(d, state.x)
32+
f_x, g_x = value_gradient!(d, state.x)
33+
state.f_x = f_x
34+
copyto!(state.g_x, g_x)
35+
return nothing
1936
end
2037

2138
# Update the Hessian
@@ -24,14 +41,14 @@ update_h!(d, state, method::SecondOrderOptimizer) = hessian!(d, state.x)
2441

2542
after_while!(d, state, method, options) = nothing
2643

27-
function initial_convergence(d, state, method::AbstractOptimizer, initial_x, options)
28-
gradient!(d, initial_x)
29-
stopped = !isfinite(value(d)) || any(!isfinite, gradient(d))
30-
g_residual(d, state) <= options.g_abstol, stopped
44+
function initial_convergence(state::AbstractOptimizerState, options::Options)
45+
stopped = !isfinite(state.f_x) || any(!isfinite, state.g_x)
46+
return g_residual(state) <= options.g_abstol, stopped
3147
end
32-
function initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options)
48+
function initial_convergence(::ZerothOrderState, ::Options)
3349
false, false
3450
end
51+
3552
function optimize(
3653
d::D,
3754
initial_x::Tx,
@@ -51,7 +68,7 @@ function optimize(
5168
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
5269
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
5370

54-
g_converged, stopped = initial_convergence(d, state, method, initial_x, options)
71+
g_converged, stopped = initial_convergence(state, options)
5572
converged = g_converged || stopped
5673
# prepare iteration counter (used to make "initial state" trace entry)
5774
iteration = 0
@@ -113,11 +130,11 @@ function optimize(
113130
end
114131
end
115132

116-
if g_calls(d) > 0 && !all(isfinite, gradient(d))
133+
if hasproperty(state, :g_x) && !all(isfinite, state.g_x)
117134
options.show_warnings && @warn "Terminated early due to NaN in gradient."
118135
break
119136
end
120-
if h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
137+
if hasproperty(state, :H_x) && !all(isfinite, state.H_x)
121138
options.show_warnings && @warn "Terminated early due to NaN in Hessian."
122139
break
123140
end
@@ -141,7 +158,7 @@ function optimize(
141158
)
142159

143160
termination_code =
144-
_termination_code(d, g_residual(d, state), state, stopped_by, options)
161+
_termination_code(d, g_residual(state), state, stopped_by, options)
145162

146163
return MultivariateOptimizationResults{
147164
typeof(method),
@@ -162,10 +179,10 @@ function optimize(
162179
x_relchange(state),
163180
Tf(options.f_abstol),
164181
Tf(options.f_reltol),
165-
f_abschange(d, state),
166-
f_relchange(d, state),
182+
f_abschange(state),
183+
f_relchange(state),
167184
Tf(options.g_abstol),
168-
g_residual(d, state),
185+
g_residual(state),
169186
tr,
170187
f_calls(d),
171188
g_calls(d),
@@ -186,13 +203,13 @@ function _termination_code(d, gres, state, stopped_by, options)
186203
elseif (iszero(options.x_abstol) && x_abschange(state) <= options.x_abstol) ||
187204
(iszero(options.x_reltol) && x_relchange(state) <= options.x_reltol)
188205
TerminationCode.NoXChange
189-
elseif (iszero(options.f_abstol) && f_abschange(d, state) <= options.f_abstol) ||
190-
(iszero(options.f_reltol) && f_relchange(d, state) <= options.f_reltol)
206+
elseif (iszero(options.f_abstol) && f_abschange(state) <= options.f_abstol) ||
207+
(iszero(options.f_reltol) && f_relchange(state) <= options.f_reltol)
191208
TerminationCode.NoObjectiveChange
192209
elseif x_abschange(state) <= options.x_abstol || x_relchange(state) <= options.x_reltol
193210
TerminationCode.SmallXChange
194-
elseif f_abschange(d, state) <= options.f_abstol ||
195-
f_relchange(d, state) <= options.f_reltol
211+
elseif f_abschange(state) <= options.f_abstol ||
212+
f_relchange(state) <= options.f_reltol
196213
TerminationCode.SmallObjectiveChange
197214
elseif stopped_by.ls_failed
198215
TerminationCode.FailedLinesearch
@@ -210,11 +227,11 @@ function _termination_code(d, gres, state, stopped_by, options)
210227
TerminationCode.HessianCalls
211228
elseif stopped_by.f_increased
212229
TerminationCode.ObjectiveIncreased
213-
elseif f_calls(d) > 0 && !isfinite(value(d))
214-
TerminationCode.GradientNotFinite
215-
elseif g_calls(d) > 0 && !all(isfinite, gradient(d))
230+
elseif !isfinite(state.f_x)
231+
TerminationCode.ObjectiveNotFinite
232+
elseif hasproperty(state, :g_x) && !all(isfinite, state.g_x)
216233
TerminationCode.GradientNotFinite
217-
elseif h_calls(d) > 0 && !(d isa TwiceDifferentiableHV) && !all(isfinite, hessian(d))
234+
elseif hasproperty(state, :H_x) && !all(isfinite, state.H_x)
218235
TerminationCode.HessianNotFinite
219236
else
220237
TerminationCode.NotImplemented

src/multivariate/solvers/constrained/fminbox.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,59 +77,66 @@ function value!!(bw::BarrierWrapper, x)
7777
bw.Fb = value(bw.b, x)
7878
bw.Ftotal = bw.mu * bw.Fb
7979
if in_box(bw, x)
80-
value!!(bw.obj, x)
81-
bw.Ftotal += value(bw.obj)
80+
F = value!!(bw.obj, x)
81+
bw.Ftotal += F
8282
end
8383
end
8484
function value_gradient!!(bw::BarrierWrapper, x)
8585
bw.Fb = value(bw.b, x)
86-
bw.Ftotal = bw.mu * bw.Fb
8786
bw.DFb .= _barrier_term_gradient.(x, bw.b.lower, bw.b.upper)
88-
bw.DFtotal .= bw.mu .* bw.DFb
8987
if in_box(bw, x)
90-
value_gradient!!(bw.obj, x)
91-
bw.Ftotal += value(bw.obj)
92-
bw.DFtotal .+= gradient(bw.obj)
88+
F, DF = value_gradient!!(bw.obj, x)
89+
bw.Ftotal = muladd(bw.mu, bw.Fb, F)
90+
bw.DFtotal .= muladd.(bw.mu, bw.DFb, DF)
91+
else
92+
bw.Ftotal = bw.mu * bw.Fb
93+
bw.DFtotal .= bw.mu .* bw.DFb
9394
end
94-
95+
return bw.Ftotal, bw.DFtotal
9596
end
9697
function value_gradient!(bb::BarrierWrapper, x)
9798
bb.DFb .= _barrier_term_gradient.(x, bb.b.lower, bb.b.upper)
9899
bb.Fb = value(bb.b, x)
99-
bb.DFtotal .= bb.mu .* bb.DFb
100-
bb.Ftotal = bb.mu * bb.Fb
101-
102100
if in_box(bb, x)
103-
value_gradient!(bb.obj, x)
104-
bb.DFtotal .+= gradient(bb.obj)
105-
bb.Ftotal += value(bb.obj)
101+
F, DF = value_gradient!(bb.obj, x)
102+
bb.DFtotal .= muladd.(bb.mu, bb.DFb, DF)
103+
bb.Ftotal = muladd(bb.mu, bb.Fb, F)
104+
else
105+
bb.DFtotal .= bb.mu .* bb.DFb
106+
bb.Ftotal = bb.mu * bb.Fb
106107
end
108+
return bb.Ftotal, bb.DFtotal
107109
end
108110
value(bb::BoxBarrier, x) =
109111
mapreduce(x -> _barrier_term_value(x...), +, zip(x, bb.lower, bb.upper))
110112
function value!(obj::BarrierWrapper, x)
111113
obj.Fb = value(obj.b, x)
112114
obj.Ftotal = obj.mu * obj.Fb
113115
if in_box(obj, x)
114-
value!(obj.obj, x)
115-
obj.Ftotal += value(obj.obj)
116+
F = value!(obj.obj, x)
117+
obj.Ftotal += F
116118
end
117119
obj.Ftotal
118120
end
119-
value(obj::BarrierWrapper) = obj.Ftotal
121+
120122
function value(obj::BarrierWrapper, x)
121-
F = obj.mu * value(obj.b, x)
123+
Fb = value(obj.b, x)
122124
if in_box(obj, x)
123-
F += value(obj.obj, x)
125+
return muladd(obj.mu, Fb, value(obj.obj, x))
126+
else
127+
return obj.mu * Fb
124128
end
125-
F
126129
end
127130
function gradient!(obj::BarrierWrapper, x)
128-
gradient!(obj.obj, x)
129-
obj.DFb .= gradient(obj.b, obj.DFb, x) # this should just be inplace?
130-
obj.DFtotal .= gradient(obj.obj) .+ obj.mu * obj.Fb
131+
obj.DFb .= _barrier_term_gradient.(x, obj.b.lower, obj.b.upper)
132+
if in_box(obj.b, x)
133+
DF = gradient!(obj.obj, x)
134+
obj.DFtotal .= muladd.(obj.mu, obj.Fb, DF)
135+
else
136+
obj.DFtotal .= obj.mu .* obj.DFb
137+
end
138+
return obj.DFtotal
131139
end
132-
gradient(obj::BarrierWrapper) = obj.DFtotal
133140

134141
# this mutates mu but not the gradients
135142
# Super unsafe in that it depends on x_df being correct!

src/multivariate/solvers/constrained/ipnewton/interior.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ ls_update!(
209209
function initial_convergence(d, state, method::ConstrainedOptimizer, initial_x, options)
210210
# TODO: Make sure state.bgrad has been evaluated at initial_x
211211
# state.bgrad normally comes from constraints.c!(..., initial_x) in initial_state
212-
gradient!(d, initial_x)
213-
stopped = !isfinite(value(d)) || any(!isfinite, gradient(d))
214-
g_residual(d, state) + norm(state.bgrad, Inf) < options.g_abstol, stopped
212+
f_x, g_x = gradient!(d, initial_x)
213+
stopped = !isfinite(f_x) || any(!isfinite, g_x)
214+
g_residual(g_x, state) + norm(state.bgrad, Inf) < options.g_abstol, stopped
215215
end
216216

217217
function optimize(
@@ -342,10 +342,10 @@ function optimize(
342342
x_relchange(state),
343343
T(options.f_abstol),
344344
T(options.f_reltol),
345-
f_abschange(d, state),
346-
f_relchange(d, state),
345+
f_abschange(state),
346+
f_relchange(state),
347347
T(options.g_abstol),
348-
g_residual(d, state),
348+
g_residual(state),
349349
tr,
350350
f_calls(d),
351351
g_calls(d),

src/multivariate/solvers/constrained/ipnewton/utilities/trace.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function trace!(tr, d, state, iteration, method::IPOptimizer, options, curr_time
4545
update!(
4646
tr,
4747
iteration,
48-
value(d),
48+
state.f_x,
4949
g_norm,
5050
dt,
5151
options.store_trace,

src/multivariate/solvers/constrained/samin.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,10 @@ function optimize(
225225
NaN,# x_abschange(state),
226226
f_tol,#T(options.f_tol),
227227
0.0,#T(options.f_tol),
228-
f_absΔ,#f_abschange(d, state),
229-
NaN,#f_abschange(d, state),
228+
f_absΔ,#f_abschange(state),
229+
NaN,#f_abschange(state),
230230
0.0,#T(options.g_tol),
231-
NaN,#g_residual(d),
231+
NaN,#g_residual(state),
232232
tr,
233233
f_calls(d),
234234
g_calls(d),

src/multivariate/solvers/first_order/accelerated_gradient_descent.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@ end
3535

3636
function initial_state(
3737
method::AcceleratedGradientDescent,
38-
options,
38+
options::Options,
3939
d,
40-
initial_x::AbstractArray{T},
41-
) where {T}
40+
initial_x::AbstractArray,
41+
)
4242
initial_x = copy(initial_x)
4343
retract!(method.manifold, initial_x)
44-
45-
value_gradient!!(d, initial_x)
46-
47-
project_tangent!(method.manifold, gradient(d), initial_x)
44+
f_x, g_x = value_gradient!(d, initial_x)
45+
project_tangent!(method.manifold, g_x, initial_x)
4846

4947
AcceleratedGradientDescentState(
5048
copy(initial_x), # Maintain current state in state.x
51-
copy(initial_x), # Maintain previous state in state.x_previous
52-
real(T)(NaN), # Store previous f in state.f_x_previous
49+
copy(g_x), # Maintain current gradient in state.g_x
50+
f_x, # Maintain current f in state.f_x
51+
fill!(similar(initial_x), NaN), # Maintain previous state in state.x_previous
52+
oftype(f_x, NaN), # Store previous f in state.f_x_previous
5353
0, # Iteration
5454
copy(initial_x), # Maintain intermediary current state in state.y
5555
similar(initial_x), # Maintain intermediary state in state.y_previous
@@ -63,11 +63,14 @@ function update_state!(
6363
state::AcceleratedGradientDescentState,
6464
method::AcceleratedGradientDescent,
6565
)
66-
value_gradient!(d, state.x)
66+
f_x, g_x = value_gradient!(d, state.x)
6767
state.iteration += 1
68-
project_tangent!(method.manifold, gradient(d), state.x)
68+
project_tangent!(method.manifold, g_x, state.x)
69+
copyto!(state.g_x, g_x)
70+
state.f_x = f_x
71+
6972
# Search direction is always the negative gradient
70-
state.s .= .-gradient(d)
73+
state.s .= .-g_x
7174

7275
# Determine the distance of movement along the search line
7376
lssuccess = perform_linesearch!(state, method, ManifoldObjective(method.manifold, d))

0 commit comments

Comments (0)