@@ -179,6 +179,34 @@ def test_tao_simple_inversion(minimize, riesz_representation):
179
179
assert_allclose (x .dat .data , source_ref .dat .data , rtol = 1e-2 )
180
180
181
181
182
+ @pytest .mark .parametrize ("minimize" , [minimize_tao_lmvm ,
183
+ pytest .param (minimize_tao_nls , marks = pytest .mark .xfail )])
184
+ @pytest .mark .parametrize ("riesz_representation" , ["L2" , "H1" ])
185
+ @pytest .mark .skipcomplex
186
+ def test_tao_cofunction_control (minimize , riesz_representation ):
187
+ """Test inversion of source term in helmholtz eqn using TAO."""
188
+ mesh = UnitIntervalMesh (10 )
189
+ V = FunctionSpace (mesh , "CG" , 1 )
190
+ source_ref = Function (V )
191
+ x = SpatialCoordinate (mesh )
192
+ source_ref .interpolate (cos (pi * x ** 2 ))
193
+
194
+ # compute reference solution
195
+ with stop_annotating ():
196
+ u_ref = _simple_helmholz_model (V , source_ref )
197
+
198
+ # now rerun annotated model with zero source
199
+ source = Cofunction (V .dual ())
200
+ c = Control (source , riesz_map = riesz_representation )
201
+ u = _simple_helmholz_model (V , source .riesz_representation (riesz_representation ))
202
+
203
+ J = assemble (1e6 * (u - u_ref )** 2 * dx )
204
+ rf = ReducedFunctional (J , c )
205
+
206
+ x = minimize (rf ).riesz_representation (riesz_representation )
207
+ assert_allclose (x .dat .data , source_ref .dat .data , rtol = 1e-2 )
208
+
209
+
182
210
class TransformType (Enum ):
183
211
PRIMAL = auto ()
184
212
DUAL = auto ()
0 commit comments