Skip to content

Commit 4d482b8

Browse files
committed
Some fixes for Cofunction controls
1 parent 795baf3 commit 4d482b8

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

firedrake/adjoint_utils/function.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,14 @@ class CofunctionMixin(FunctionMixin):
341341

342342
def _ad_dot(self, other):
343343
return firedrake.assemble(firedrake.action(self, other))
344+
345+
def _ad_init_object(cls, obj):
346+
from firedrake import Cofunction
347+
return Cofunction(cls.function_space()).assign(obj)
348+
349+
def _ad_init_zero(self, dual=False):
350+
from firedrake import Function, Cofunction
351+
if dual:
352+
return Function(self.function_space().dual())
353+
else:
354+
return Cofunction(self.function_space())

tests/firedrake/adjoint/test_optimisation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,34 @@ def test_tao_simple_inversion(minimize, riesz_representation):
179179
assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2)
180180

181181

182+
@pytest.mark.parametrize("minimize", [minimize_tao_lmvm,
183+
pytest.param(minimize_tao_nls, marks=pytest.mark.xfail)])
184+
@pytest.mark.parametrize("riesz_representation", ["L2", "H1"])
185+
@pytest.mark.skipcomplex
186+
def test_tao_cofunction_control(minimize, riesz_representation):
187+
"""Test inversion of source term in helmholtz eqn using TAO."""
188+
mesh = UnitIntervalMesh(10)
189+
V = FunctionSpace(mesh, "CG", 1)
190+
source_ref = Function(V)
191+
x = SpatialCoordinate(mesh)
192+
source_ref.interpolate(cos(pi*x**2))
193+
194+
# compute reference solution
195+
with stop_annotating():
196+
u_ref = _simple_helmholz_model(V, source_ref)
197+
198+
# now rerun annotated model with zero source
199+
source = Cofunction(V.dual())
200+
c = Control(source, riesz_map=riesz_representation)
201+
u = _simple_helmholz_model(V, source.riesz_representation(riesz_representation))
202+
203+
J = assemble(1e6 * (u - u_ref)**2*dx)
204+
rf = ReducedFunctional(J, c)
205+
206+
x = minimize(rf).riesz_representation(riesz_representation)
207+
assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2)
208+
209+
182210
class TransformType(Enum):
183211
PRIMAL = auto()
184212
DUAL = auto()

0 commit comments

Comments
 (0)