Skip to content

Commit

Permalink
None option for Riesz Representation in derivatives (#3552)
Browse files Browse the repository at this point in the history
Also default to L2 projection.

---------

Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
Co-authored-by: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com>
  • Loading branch information
3 people authored May 8, 2024
1 parent d058083 commit af2c9cc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
9 changes: 7 additions & 2 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,13 @@ def _define_riesz_map_form(self, riesz_representation, V):

@no_annotations
def _ad_convert_type(self, value, options=None):
# `_ad_convert_type` is not annoated unlike to `_ad_convert_riesz`
return self._ad_convert_riesz(value, options=options)
# `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz`
options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
if riesz_representation is None:
return value
else:
return self._ad_convert_riesz(value, options=options)

def _ad_restore_at_checkpoint(self, checkpoint):
if isinstance(checkpoint, CheckpointBase):
Expand Down
13 changes: 13 additions & 0 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,3 +886,16 @@ def test_cofunction_subfunctions_with_adjoint():
k.block_variable.tlm_value = Constant(1)
get_working_tape().evaluate_tlm()
assert taylor_test(J_hat, k, Constant(1.0), dJdm=J.block_variable.tlm_value) > 1.9


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_none_riesz_representation_to_derivative():
mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)
u = Function(space).interpolate(SpatialCoordinate(mesh)[0])
J = assemble((u ** 2) * dx)
rf = ReducedFunctional(J, Control(u))
assert isinstance(rf.derivative(), Function)
assert isinstance(rf.derivative(options={"riesz_representation": "H1"}), Function)
assert isinstance(rf.derivative(options={"riesz_representation": "L2"}), Function)
assert isinstance(rf.derivative(options={"riesz_representation": None}), Cofunction)

0 comments on commit af2c9cc

Please sign in to comment.