Skip to content

Commit

Permalink
Return zero derivatives according Riesz representation option (#3637)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci authored Jun 19, 2024
1 parent 97e883c commit 58cb608
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
18 changes: 11 additions & 7 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,15 @@ def _ad_create_checkpoint(self):
def _ad_convert_riesz(self, value, options=None):
from firedrake import Function, Cofunction

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
V = options.get("function_space", self.function_space())
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
# In this case, we do not apply the Riesz map and return a zero
# Cofunction.
return Cofunction(V.dual())
return Function(V)

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
if not isinstance(value, (Cofunction, Function)):
raise TypeError("Expected a Cofunction or a Function")

Expand Down Expand Up @@ -279,7 +277,13 @@ def _ad_convert_type(self, value, options=None):
options = {} if options is None else options.copy()
options.setdefault("riesz_representation", "L2")
if options["riesz_representation"] is None:
return value
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
V = options.get("function_space", self.function_space())
return firedrake.Cofunction(V.dual())
else:
return value
else:
return self._ad_convert_riesz(value, options=options)

Expand Down
12 changes: 11 additions & 1 deletion tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,17 @@ def test_assign_zero_cofunction():
J = assemble(((sol + Constant(1.0)) ** 2) * dx)
# The zero assignment should break the tape and hence cause a zero
# gradient.
assert all(compute_gradient(J, Control(k)).dat.data_ro == 0.0)
grad_l2 = compute_gradient(J, Control(k), options={"riesz_representation": "l2"})
grad_none = compute_gradient(J, Control(k), options={"riesz_representation": None})
grad_h1 = compute_gradient(J, Control(k), options={"riesz_representation": "H1"})
grad_L2 = compute_gradient(J, Control(k), options={"riesz_representation": "L2"})
assert isinstance(grad_l2, Function) and isinstance(grad_L2, Function) \
and isinstance(grad_h1, Function)
assert isinstance(grad_none, Cofunction)
assert all(grad_none.dat.data_ro == 0.0)
assert all(grad_l2.dat.data_ro == 0.0)
assert all(grad_h1.dat.data_ro == 0.0)
assert all(grad_L2.dat.data_ro == 0.0)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
Expand Down

0 comments on commit 58cb608

Please sign in to comment.