Skip to content

Commit 28b8899

Browse files
committed
Compute JVP in line searches
1 parent d4bf817 commit 28b8899

File tree

32 files changed

+224
-210
lines changed

32 files changed

+224
-210
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- "min"
15-
- "lts"
14+
# - "min"
15+
# - "lts"
1616
- "1"
1717
os:
1818
- ubuntu-latest

.github/workflows/Docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
with:
2020
version: '1'
2121
- name: Install dependencies
22-
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
22+
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
2323
- name: Build and deploy
2424
env:
2525
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ LineSearches = "7.4.0"
3434
LinearAlgebra = "<0.0.1, 1.6"
3535
MathOptInterface = "1.17"
3636
Measurements = "2.14.1"
37-
NLSolversBase = "7.9.0"
37+
NLSolversBase = "8"
3838
NaNMath = "0.3.2, 1"
3939
OptimTestProblems = "2.0.3"
4040
PositiveFactorizations = "0.2.2"
@@ -65,3 +65,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6565

6666
[targets]
6767
test = ["Test", "Aqua", "Distributions", "ExplicitImports", "ForwardDiff", "JET", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "ReverseDiff"]
68+
69+
[sources]
70+
LineSearches = { url = "https://github.com/devmotion/LineSearches.jl.git", rev = "dmw/jvp" }
71+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
Documenter = "1"
1313
Literate = "2"
1414

15-
[sources.Optim]
16-
path = ".."
15+
[sources]
16+
Optim = { path = ".." }
17+
NLSolversBase = { url = "https://github.com/devmotion/NLSolversBase.jl.git", rev = "dmw/jvp" }

docs/src/examples/ipnewton_basics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# constraint is unbounded from below or above respectively.
2323

2424
using Optim, NLSolversBase #hide
25+
import ADTypes #hide
2526
import NLSolversBase: clear! #hide
2627

2728
# # Constrained optimization with `IPNewton`

ext/OptimMOIExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module OptimMOIExt
22

33
using Optim
4-
using Optim.LinearAlgebra: rmul!
4+
using Optim.LinearAlgebra: rmul!
55
import MathOptInterface as MOI
66

77
function __init__()

src/Manifolds.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020
# Non-mutating retraction: applies `retract!` to a copy of `x`, so `x` itself is left untouched.
retract(M::Manifold, x) = retract!(M, copy(x))
2121

2222
# Fake objective function implementing a retraction
23-
mutable struct ManifoldObjective{T<:NLSolversBase.AbstractObjective} <:
24-
NLSolversBase.AbstractObjective
25-
manifold::Manifold
23+
# Objective wrapper that pairs an inner objective with a manifold.
# The manifold type is a type parameter, so methods can dispatch on it
# (e.g. `ManifoldObjective{Flat}`).
struct ManifoldObjective{M<:Manifold,T<:AbstractObjective} <: AbstractObjective
24+
# Manifold passed to `retract` / `project_tangent!` when evaluating the objective.
manifold::M
2625
# Wrapped objective on which the actual evaluations are performed.
inner_obj::T
2726
end
2827
# TODO: is it safe here to call retract! and change x?
@@ -52,6 +51,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
5251
return value(obj.inner_obj)
5352
end
5453

54+
# In general, we have to compute the gradient/Jacobian separately as it has to be projected
55+
# Jacobian-vector product for a manifold-wrapped objective.
# Evaluates the inner gradient at the retracted point, projects it with
# `project_tangent!`, and returns its `dot` product with `v`.
function NLSolversBase.jvp!(obj::ManifoldObjective, x, v)
56+
# `retract` works on a copy of `x`, so the caller's `x` is not modified here.
xin = retract(obj.manifold, x)
57+
# Update the gradient stored in the inner objective at the retracted point.
gradient!(obj.inner_obj, xin)
58+
# The gradient object of the inner objective is passed directly, so the
# projection presumably overwrites the stored gradient in place.
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
59+
# NOTE(review): for complex-valued inputs `dot` may return a complex number;
# confirm whether callers expect the real part here.
return dot(gradient(obj.inner_obj), v)
60+
end
61+
# Objective value and Jacobian-vector product for a manifold-wrapped objective.
# Returns the tuple `(value, dot(projected_gradient, v))`, both taken from the
# inner objective after evaluating it at the retracted point.
function NLSolversBase.value_jvp!(obj::ManifoldObjective, x, v)
62+
# `retract` works on a copy of `x`, so the caller's `x` is not modified here.
xin = retract(obj.manifold, x)
63+
# Evaluate value and gradient of the inner objective together at the retracted point.
value_gradient!(obj.inner_obj, xin)
64+
# The gradient object of the inner objective is passed directly, so the
# projection presumably overwrites the stored gradient in place.
project_tangent!(obj.manifold, gradient(obj.inner_obj), xin)
65+
# NOTE(review): for complex-valued inputs `dot` may return a complex number;
# confirm whether callers expect the real part here.
return value(obj.inner_obj), dot(gradient(obj.inner_obj), v)
66+
end
67+
5568
"""Flat Euclidean space {R,C}^N, with projections equal to the identity."""
5669
struct Flat <: Manifold end
5770
# all the functions below are no-ops, and therefore the generated code
@@ -62,6 +75,10 @@ retract!(M::Flat, x) = x
6275
project_tangent(M::Flat, g, x) = g
6376
project_tangent!(M::Flat, g, x) = g
6477

78+
# Optimizations for `Flat` manifold
79+
# On `Flat`, `retract!` and `project_tangent!` return their argument unchanged, so the
# retract/project steps are skipped and the call is forwarded to the inner objective.
NLSolversBase.jvp!(obj::ManifoldObjective{Flat}, x, v) = jvp!(obj.inner_obj, x, v)
80+
NLSolversBase.value_jvp!(obj::ManifoldObjective{Flat}, x, v) = value_jvp!(obj.inner_obj, x, v)
81+
6582
"""Spherical manifold {|x| = 1}."""
6683
struct Sphere <: Manifold end
6784
# Retraction onto the sphere: divides `x` by its norm in place and returns `x`.
retract!(S::Sphere, x) = (x ./= norm(x))

src/Optim.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ documentation online at http://julianlsolvers.github.io/Optim.jl/stable/ .
1616
"""
1717
module Optim
1818

19+
import ADTypes
20+
1921
using PositiveFactorizations: Positive # for globalization strategy in Newton
2022

2123
using LineSearches: LineSearches # for globalization strategy in Quasi-Newton algs
@@ -35,15 +37,21 @@ using NLSolversBase:
3537
NonDifferentiable,
3638
OnceDifferentiable,
3739
TwiceDifferentiable,
38-
TwiceDifferentiableHV,
3940
AbstractConstraints,
4041
ConstraintBounds,
4142
TwiceDifferentiableConstraints,
4243
nconstraints,
4344
nconstraints_x,
45+
value,
46+
value!,
47+
gradient,
48+
gradient!,
49+
value_gradient!,
50+
jvp,
51+
jvp!,
52+
value_jvp!,
4453
hessian,
4554
hessian!,
46-
hessian!!,
4755
hv_product,
4856
hv_product!
4957

src/api.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,15 @@ g_norm_trace(r::OptimizationResults) =
100100
g_norm_trace(r::MultivariateOptimizationResults) = [state.g_norm for state in trace(r)]
101101

102102
f_calls(r::OptimizationResults) = r.f_calls
103-
f_calls(d) = first(d.f_calls)
103+
# For an objective, forward to the `f_calls` counter maintained by NLSolversBase.
f_calls(d::AbstractObjective) = NLSolversBase.f_calls(d)
104104

105105
g_calls(r::OptimizationResults) = error("g_calls is not implemented for $(summary(r)).")
106106
g_calls(r::MultivariateOptimizationResults) = r.g_calls
107-
g_calls(d::NonDifferentiable) = 0
108-
g_calls(d) = first(d.df_calls)
107+
# For an objective, report the sum of NLSolversBase's `g_calls` and `jvp_calls`,
# so Jacobian-vector products are counted together with gradient evaluations.
g_calls(d::AbstractObjective) = NLSolversBase.g_calls(d) + NLSolversBase.jvp_calls(d)
109108

110109
h_calls(r::OptimizationResults) = error("h_calls is not implemented for $(summary(r)).")
111110
h_calls(r::MultivariateOptimizationResults) = r.h_calls
112-
h_calls(d::Union{NonDifferentiable,OnceDifferentiable}) = 0
113-
h_calls(d) = first(d.h_calls)
114-
h_calls(d::TwiceDifferentiableHV) = first(d.hv_calls)
111+
# For an objective, report the sum of NLSolversBase's `h_calls` and `hv_calls`,
# so Hessian-vector products are counted together with Hessian evaluations.
h_calls(d::AbstractObjective) = NLSolversBase.h_calls(d) + NLSolversBase.hv_calls(d)
115112

116113
converged(r::UnivariateOptimizationResults) = r.stopped_by.converged
117114
function converged(r::MultivariateOptimizationResults)

src/multivariate/optimize/interface.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,6 @@ promote_objtype(
6666
inplace::Bool,
6767
f::InplaceObjective,
6868
) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
69-
promote_objtype(
70-
method::SecondOrderOptimizer,
71-
x,
72-
autodiff::ADTypes.AbstractADType,
73-
inplace::Bool,
74-
f::NLSolversBase.InPlaceObjectiveFGHv,
75-
) = TwiceDifferentiableHV(f, x)
76-
promote_objtype(
77-
method::SecondOrderOptimizer,
78-
x,
79-
autodiff::ADTypes.AbstractADType,
80-
inplace::Bool,
81-
f::NLSolversBase.InPlaceObjectiveFG_Hv,
82-
) = TwiceDifferentiableHV(f, x)
8369
promote_objtype(method::SecondOrderOptimizer, x, autodiff::ADTypes.AbstractADType, inplace::Bool, f, g) =
8470
TwiceDifferentiable(
8571
f,

0 commit comments

Comments
 (0)