Skip to content

Commit 6dfafb0

Browse files
Ig-dolcidhamJDBetteridge
committed
None option for Riesz Representation in derivatives (#3552)
Also default to L2 projection. --------- Co-authored-by: David A. Ham <david.ham@imperial.ac.uk> Co-authored-by: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com>
1 parent e6fa04a commit 6dfafb0

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

firedrake/adjoint_utils/function.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,13 @@ def _define_riesz_map_form(self, riesz_representation, V):
275275

276276
@no_annotations
277277
def _ad_convert_type(self, value, options=None):
278-
# `_ad_convert_type` is not annoated unlike to `_ad_convert_riesz`
279-
return self._ad_convert_riesz(value, options=options)
278+
# `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz`
279+
options = {} if options is None else options
280+
riesz_representation = options.get("riesz_representation", "L2")
281+
if riesz_representation is None:
282+
return value
283+
else:
284+
return self._ad_convert_riesz(value, options=options)
280285

281286
def _ad_restore_at_checkpoint(self, checkpoint):
282287
if isinstance(checkpoint, CheckpointBase):

tests/regression/test_adjoint_operators.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,3 +886,16 @@ def test_cofunction_subfunctions_with_adjoint():
886886
k.block_variable.tlm_value = Constant(1)
887887
get_working_tape().evaluate_tlm()
888888
assert taylor_test(J_hat, k, Constant(1.0), dJdm=J.block_variable.tlm_value) > 1.9
889+
890+
891+
@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
892+
def test_none_riesz_representation_to_derivative():
893+
mesh = UnitIntervalMesh(1)
894+
space = FunctionSpace(mesh, "Lagrange", 1)
895+
u = Function(space).interpolate(SpatialCoordinate(mesh)[0])
896+
J = assemble((u ** 2) * dx)
897+
rf = ReducedFunctional(J, Control(u))
898+
assert isinstance(rf.derivative(), Function)
899+
assert isinstance(rf.derivative(options={"riesz_representation": "H1"}), Function)
900+
assert isinstance(rf.derivative(options={"riesz_representation": "L2"}), Function)
901+
assert isinstance(rf.derivative(options={"riesz_representation": None}), Cofunction)

0 commit comments

Comments
 (0)