@@ -20,9 +20,8 @@ project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)
2020retract (M:: Manifold , x) = retract! (M, copy (x))
2121
2222# Fake objective function implementing a retraction
23- mutable struct ManifoldObjective{T<: NLSolversBase.AbstractObjective } < :
24- NLSolversBase. AbstractObjective
25- manifold:: Manifold
23+ struct ManifoldObjective{M<: Manifold ,T<: AbstractObjective } <: AbstractObjective
24+ manifold:: M
2625 inner_obj:: T
2726end
2827# TODO : is it safe here to call retract! and change x?
@@ -52,6 +51,20 @@ function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
5251 return value (obj. inner_obj)
5352end
5453
54+ # In general, we have to compute the gradient/Jacobian separately as it has to be projected
55+ function NLSolversBase. jvp! (obj:: ManifoldObjective , x, v)
56+ xin = retract (obj. manifold, x)
57+ gradient! (obj. inner_obj, xin)
58+ project_tangent! (obj. manifold, gradient (obj. inner_obj), xin)
59+ return dot (gradient (obj. inner_obj), v)
60+ end
61+ function NLSolversBase. value_jvp! (obj:: ManifoldObjective , x, v)
62+ xin = retract (obj. manifold, x)
63+ value_gradient! (obj. inner_obj, xin)
64+ project_tangent! (obj. manifold, gradient (obj. inner_obj), xin)
65+ return value (obj. inner_obj), dot (gradient (obj. inner_obj), v)
66+ end
67+
5568""" Flat Euclidean space {R,C}^N, with projections equal to the identity."""
5669struct Flat <: Manifold end
5770# all the functions below are no-ops, and therefore the generated code
@@ -62,6 +75,10 @@ retract!(M::Flat, x) = x
6275project_tangent (M:: Flat , g, x) = g
6376project_tangent! (M:: Flat , g, x) = g
6477
78+ # Optimizations for `Flat` manifold
79+ NLSolversBase. jvp! (obj:: ManifoldObjective{Flat} , x, v) = jvp! (obj. inner_obj, x, v)
80+ NLSolversBase. value_jvp! (obj:: ManifoldObjective{Flat} , x, v) = value_jvp! (obj. inner_obj, x, v)
81+
6582""" Spherical manifold {|x| = 1}."""
6683struct Sphere <: Manifold end
6784retract! (S:: Sphere , x) = (x ./= norm (x))
0 commit comments