From 58cb608388e1b5ffe9eaa5018e06d11a3d56dc94 Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:06:29 +0100 Subject: [PATCH] Return zero derivatives according Riesz representation option (#3637) --- firedrake/adjoint_utils/function.py | 18 +++++++++++------- tests/regression/test_adjoint_operators.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index b490f34128..d56438fec7 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -223,17 +223,15 @@ def _ad_create_checkpoint(self): def _ad_convert_riesz(self, value, options=None): from firedrake import Function, Cofunction + options = {} if options is None else options + riesz_representation = options.get("riesz_representation", "L2") + solver_options = options.get("solver_options", {}) V = options.get("function_space", self.function_space()) if value == 0.: # In adjoint-based differentiation, value == 0. arises only when # the functional is independent on the control variable. - # In this case, we do not apply the Riesz map and return a zero - # Cofunction. - return Cofunction(V.dual()) + return Function(V) - options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "L2") - solver_options = options.get("solver_options", {}) if not isinstance(value, (Cofunction, Function)): raise TypeError("Expected a Cofunction or a Function") @@ -279,7 +277,13 @@ def _ad_convert_type(self, value, options=None): options = {} if options is None else options.copy() options.setdefault("riesz_representation", "L2") if options["riesz_representation"] is None: - return value + if value == 0.: + # In adjoint-based differentiation, value == 0. arises only when + # the functional is independent on the control variable. + V = options.get("function_space", self.function_space()) + return firedrake.Cofunction(V.dual()) + else: + return value else: return self._ad_convert_riesz(value, options=options) diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 14a023bd16..4aa66f84fd 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -864,7 +864,17 @@ def test_assign_zero_cofunction(): J = assemble(((sol + Constant(1.0)) ** 2) * dx) # The zero assignment should break the tape and hence cause a zero # gradient. - assert all(compute_gradient(J, Control(k)).dat.data_ro == 0.0) + grad_l2 = compute_gradient(J, Control(k), options={"riesz_representation": "l2"}) + grad_none = compute_gradient(J, Control(k), options={"riesz_representation": None}) + grad_h1 = compute_gradient(J, Control(k), options={"riesz_representation": "H1"}) + grad_L2 = compute_gradient(J, Control(k), options={"riesz_representation": "L2"}) + assert isinstance(grad_l2, Function) and isinstance(grad_L2, Function) \ + and isinstance(grad_h1, Function) + assert isinstance(grad_none, Cofunction) + assert all(grad_none.dat.data_ro == 0.0) + assert all(grad_l2.dat.data_ro == 0.0) + assert all(grad_h1.dat.data_ro == 0.0) + assert all(grad_L2.dat.data_ro == 0.0) @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done