Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Cofunction.sub with adjoint #3470

Merged
merged 13 commits into from
May 1, 2024
6 changes: 3 additions & 3 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def evaluate_tlm(self):
return
output = self.get_outputs()[0]
fs = output.output.function_space()
f = firedrake.Function(fs)
f = type(output.output)(fs)
output.add_tlm_output(
firedrake.Function.assign(f.sub(self.idx), tlm_input)
type(output.output).assign(f.sub(self.idx), tlm_input)
)

def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
Expand All @@ -271,7 +271,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
def recompute_component(self, inputs, block_variable, idx, prepared):
sub_func = inputs[0]
parent_in = inputs[1]
parent_out = firedrake.Function(parent_in)
parent_out = type(parent_in)(parent_in)
parent_out.sub(self.idx).assign(sub_func)
return maybe_disk_checkpoint(parent_out)

Expand Down
18 changes: 9 additions & 9 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def wrapper(self, *args, **kwargs):
output = subfunctions(self, *args, **kwargs)

if annotate:
output = tuple(firedrake.Function(output[i].function_space(),
output[i],
block_class=SubfunctionBlock,
_ad_floating_active=True,
_ad_args=[self, i],
_ad_output_args=[i],
output_block_class=FunctionMergeBlock,
_ad_outputs=[self],
ad_block_tag=ad_block_tag)
output = tuple(type(self)(output[i].function_space(),
output[i],
block_class=SubfunctionBlock,
_ad_floating_active=True,
_ad_args=[self, i],
_ad_output_args=[i],
output_block_class=FunctionMergeBlock,
dham marked this conversation as resolved.
Show resolved Hide resolved
_ad_outputs=[self],
ad_block_tag=ad_block_tag)
for i in range(len(output)))
return output
return wrapper
Expand Down
4 changes: 2 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
self._name = name or 'cofunction_%d' % self.uid
self._label = "a cofunction"

if isinstance(val, vector.Vector):
# Allow constructing using a vector.
if isinstance(val, (Cofunction, vector.Vector)):
dham marked this conversation as resolved.
Show resolved Hide resolved
val = val.dat

if isinstance(val, (op2.Dat, op2.DatView, op2.MixedDat, op2.Global)):
assert val.comm == self._comm
self.dat = val
Expand Down
29 changes: 29 additions & 0 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,3 +857,32 @@ def test_assign_zero_cofunction():
# The zero assignment should break the tape and hence cause a zero
# gradient.
assert all(compute_gradient(J, Control(k)).dat.data_ro == 0.0)


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_cofunction_subfunctions_with_adjoint():
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
# See https://github.com/firedrakeproject/firedrake/issues/3469
mesh = UnitSquareMesh(2, 2)
BDM = FunctionSpace(mesh, "BDM", 1)
DG = FunctionSpace(mesh, "DG", 0)
W = BDM * DG
sigma, u = TrialFunctions(W)
tau, v = TestFunctions(W)
x, y = SpatialCoordinate(mesh)
f = Function(DG).interpolate(
10*exp(-(pow(x - 0.5, 2) + pow(y - 0.5, 2)) / 0.02))
bc0 = DirichletBC(W.sub(0), as_vector([0.0, -sin(5*x)]), 3)
bc1 = DirichletBC(W.sub(0), as_vector([0.0, sin(5*x)]), 4)
k = Function(DG).assign(1.0)
a = (dot(sigma, tau) + (dot(div(tau), u))) * dx + k * div(sigma)*v*dx
b = assemble(-f*TestFunction(DG)*dx)
w = Function(W)
b1 = Cofunction(W.dual())
# The following operation generates the FunctionMergeBlock.
b1.sub(1).interpolate(b)
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
solve(a == b1, w, bcs=[bc0, bc1])
J = assemble(0.5*dot(w, w)*dx)
J_hat = ReducedFunctional(J, Control(k))
k.block_variable.tlm_value = Constant(1)
get_working_tape().evaluate_tlm()
assert taylor_test(J_hat, k, Constant(1.0), dJdm=J.block_variable.tlm_value) > 1.9
Loading