Skip to content

Commit bbb6817

Browse files
committed
Start to move TLM evaluation into NonlinearVariationalSolveBlock
1 parent 2b008b3 commit bbb6817

File tree

2 files changed

+92
-4
lines changed

2 files changed

+92
-4
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Solver(Enum):
2929
"""Enum for solver types."""
3030
FORWARD = 0
3131
ADJOINT = 1
32+
TLM = 2
3233

3334

3435
class GenericSolveBlock(Block):
@@ -681,8 +682,11 @@ def _adjoint_solve(self, dJdu, compute_bdy):
681682
def _ad_assign_map(self, form, solver):
682683
if solver == Solver.FORWARD:
683684
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
684-
else:
685+
elif solver == Solver.ADJOINT:
685686
count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map
687+
elif solver == Solver.TLM:
688+
count_map = self._ad_solvers["tlm_lvs"]._problem._ad_count_map
689+
686690
assign_map = {}
687691
form_ad_count_map = dict((count_map[coeff], coeff)
688692
for coeff in form.coefficients())
@@ -717,9 +721,13 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
717721
problem = self._ad_solvers["forward_nlvs"]._problem
718722
self._ad_assign_coefficients(problem.F, solver)
719723
self._ad_assign_coefficients(problem.J, solver)
720-
else:
724+
elif solver == Solver.ADJOINT:
721725
self._ad_assign_coefficients(
722726
self._ad_solvers["adjoint_lvs"]._problem.J, solver)
727+
elif solver == Solver.TLM:
728+
self._ad_assign_coefficients(
729+
self._ad_solvers["tlm_lvs"]._problem.J, solver
730+
)
723731

724732
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
725733
compute_bdy = self._should_compute_boundary_adjoint(
@@ -796,6 +804,59 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
796804

797805
return dFdm
798806

807+
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
808+
prepared=None):
809+
F_form = prepared["form"]
810+
dFdu = prepared["dFdu"]
811+
812+
bcs = []
813+
dFdm = 0.
814+
for block_variable in self.get_dependencies():
815+
tlm_value = block_variable.tlm_value
816+
c = block_variable.output
817+
c_rep = block_variable.saved_output
818+
819+
if isinstance(c, firedrake.DirichletBC):
820+
if tlm_value is None:
821+
bcs.append(c.reconstruct(g=0))
822+
else:
823+
bcs.append(tlm_value)
824+
continue
825+
elif isinstance(c, firedrake.MeshGeometry):
826+
X = firedrake.SpatialCoordinate(c)
827+
c_rep = X
828+
829+
if tlm_value is None:
830+
continue
831+
832+
if c == self.func and not self.linear:
833+
continue
834+
835+
dFdm += firedrake.derivative(-F_form, c_rep, tlm_value)
836+
837+
if isinstance(dFdm, float):
838+
v = dFdu.arguments()[0]
839+
dFdm = firedrake.inner(
840+
firedrake.Constant(numpy.zeros(v.ufl_shape)), v
841+
) * firedrake.dx
842+
843+
dFdm = ufl.algorithms.expand_derivatives(dFdm)
844+
dFdm = firedrake.assemble(dFdm)
845+
846+
# XXX I dunno how this works
847+
self._ad_solver_replace_forms(Solver.TLM)
848+
self._ad_solvers["tlm_lvs"].invalidate_jacobian()
849+
# update RHS
850+
self._ad_solvers["tlm_lvs"]._problem.F._components[1].assign(dFdm)
851+
852+
self._ad_solvers["tlm_lvs"].solve()
853+
return self._ad_solvers["tlm_lvs"]._problem.u
854+
# return self._assemble_and_solve_tlm_eq(
855+
# firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
856+
# dFdm, dudm, bcs
857+
# )
858+
859+
799860

800861
class ProjectBlock(SolveVarFormBlock):
801862
def __init__(self, v, V, output, bcs=[], *args, **kwargs):

firedrake/adjoint_utils/variational_solver.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def wrapper(self, problem, *args, **kwargs):
4949
self._ad_args = args
5050
self._ad_kwargs = kwargs
5151
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
52-
"recompute_count": 0}
52+
"recompute_count": 0, "tlm_lvs": None}
5353
self._ad_adj_cache = {}
5454

5555
return wrapper
@@ -100,6 +100,14 @@ def wrapper(self, **kwargs):
100100
if self._ad_problem._constant_jacobian:
101101
self._ad_solvers["update_adjoint"] = False
102102

103+
if not self._ad_solvers["tlm_lvs"]:
104+
with stop_annotating():
105+
self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
106+
self._ad_tlm_lvs_problem(block, problem.F, problem.u_restrict)
107+
)
108+
if self._ad_problem._constant_jacobian:
109+
self._ad_solvers["update_tlm"] = False
110+
103111
block._ad_solvers = self._ad_solvers
104112

105113
tape.add_block(block)
@@ -151,14 +159,33 @@ def _ad_adj_lvs_problem(self, block, adj_F):
151159
# linear variational problem is created with a deep copy of the
152160
# `block.adj_F` coefficients.
153161
_ad_count_map, J_replace_map, _ = self._build_count_map(
154-
adj_F, block._dependencies)
162+
adj_F, block._dependencies,
163+
)
155164
lvp = LinearVariationalProblem(
156165
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
157166
bcs=tmp_problem.bcs,
158167
constant_jacobian=self._ad_problem._constant_jacobian)
159168
lvp._ad_count_map_update(_ad_count_map)
160169
return lvp
161170

171+
@no_annotations
172+
def _ad_tlm_lvs_problem(self, block, F, u):
173+
from firedrake import Function, Cofunction, LinearVariationalProblem
174+
175+
lhs = derivative(F, u)
176+
_ad_count_map, F_replace_map, _ = self._build_count_map(lhs, block._dependencies)
177+
sol = Function(block.function_space)
178+
rhs = Cofunction(block.function_space.dual())
179+
lvp = LinearVariationalProblem(
180+
replace(lhs, F_replace_map),
181+
rhs,
182+
sol,
183+
bcs=block._homogenize_bcs(),
184+
constant_jacobian=self._ad_problem._constant_jacobian,
185+
)
186+
lvp._ad_count_map_update(_ad_count_map)
187+
return lvp
188+
162189
def _build_count_map(self, J, dependencies, F=None):
163190
from firedrake import Function
164191

0 commit comments

Comments
 (0)