Skip to content

Commit 538bd82

Browse files
committed
Start to move Hessian evaluation into NonlinearVariationalSolveBlock
1 parent bbb6817 commit 538bd82

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Solver(Enum):
3030
FORWARD = 0
3131
ADJOINT = 1
3232
TLM = 2
33+
HESSIAN = 3
3334

3435

3536
class GenericSolveBlock(Block):
@@ -221,6 +222,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
221222

222223
return adj_sol, adj_sol_bdy
223224

225+
def _hessian_solve(self, *args):
226+
return self._assemble_and_solve_adj_eq(*args)
227+
224228
def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
225229
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
226230
return adj_sol_bdy.riesz_representation("l2")
@@ -379,8 +383,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
379383
b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input,
380384
d2Fdu2)
381385
dFdu_form = firedrake.adjoint(dFdu_form)
382-
adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b,
383-
compute_bdy)
386+
adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy)
384387
if self.adj2_cb is not None:
385388
self.adj2_cb(adj_sol2)
386389
if self.adj2_bdy_cb is not None and compute_bdy:
@@ -679,6 +682,22 @@ def _adjoint_solve(self, dJdu, compute_bdy):
679682
u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
680683
return u_sol, adj_sol_bdy
681684

685+
def _hessian_solve(self, adj_form, rhs, compute_bdy):
686+
# self._ad_solver_replace_forms(Solver.HESSIAN)
687+
# self._ad_solvers["hessian_lvs"].invalidate_jacobian()
688+
self._ad_solvers["hessian_lvs"]._problem.F._components[1].assign(rhs)
689+
self._ad_solvers["hessian_lvs"].solve()
690+
u_sol = self._ad_solvers["hessian_lvs"]._problem.u
691+
692+
adj_sol_bdy = None
693+
if compute_bdy:
694+
jac_adj = self._ad_solvers["hessian_lvs"]._problem.J
695+
adj_sol_bdy = self._compute_adj_bdy(
696+
u_sol, adj_sol_bdy, jac_adj, rhs.copy()
697+
)
698+
699+
return u_sol, adj_sol_bdy
700+
682701
def _ad_assign_map(self, form, solver):
683702
if solver == Solver.FORWARD:
684703
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
@@ -697,8 +716,10 @@ def _ad_assign_map(self, form, solver):
697716
firedrake.Cofunction)):
698717
coeff_count = coeff.count()
699718
if coeff_count in form_ad_count_map:
700-
assign_map[form_ad_count_map[coeff_count]] = \
701-
block_variable.saved_output
719+
if solver == Solver.HESSIAN:
720+
assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value
721+
else:
722+
assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output
702723

703724
if (
704725
solver == Solver.ADJOINT
@@ -709,6 +730,7 @@ def _ad_assign_map(self, form, solver):
709730
if coeff_count in form_ad_count_map:
710731
assign_map[form_ad_count_map[coeff_count]] = \
711732
block_variable.saved_output
733+
712734
return assign_map
713735

714736
def _ad_assign_coefficients(self, form, solver):
@@ -728,6 +750,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
728750
self._ad_assign_coefficients(
729751
self._ad_solvers["tlm_lvs"]._problem.J, solver
730752
)
753+
elif solver == Solver.HESSIAN:
754+
self._ad_assign_coefficients(
755+
self._ad_solvers["hessian_lvs"]._problem.J, solver
756+
)
731757

732758
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
733759
compute_bdy = self._should_compute_boundary_adjoint(
@@ -851,11 +877,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
851877

852878
self._ad_solvers["tlm_lvs"].solve()
853879
return self._ad_solvers["tlm_lvs"]._problem.u
854-
# return self._assemble_and_solve_tlm_eq(
855-
# firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
856-
# dFdm, dudm, bcs
857-
# )
858-
859880

860881

861882
class ProjectBlock(SolveVarFormBlock):

firedrake/adjoint_utils/variational_solver.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def wrapper(self, *args, **kwargs):
1717
self._ad_u = self.u_restrict
1818
self._ad_bcs = self.bcs
1919
self._ad_J = self.J
20+
2021
try:
2122
# Some forms (e.g. SLATE tensors) are not currently
2223
# differentiable.
@@ -27,8 +28,10 @@ def wrapper(self, *args, **kwargs):
2728
# Try again without expanding derivatives,
2829
# as dFdu might have been simplied to an empty Form
2930
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
31+
3032
except (TypeError, NotImplementedError):
3133
self._ad_adj_F = None
34+
3235
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
3336
self._ad_count_map = {}
3437
return wrapper
@@ -49,7 +52,8 @@ def wrapper(self, problem, *args, **kwargs):
4952
self._ad_args = args
5053
self._ad_kwargs = kwargs
5154
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
52-
"recompute_count": 0, "tlm_lvs": None}
55+
"recompute_count": 0, "tlm_lvs": None,
56+
"hessian_lvs": None}
5357
self._ad_adj_cache = {}
5458

5559
return wrapper
@@ -100,6 +104,12 @@ def wrapper(self, **kwargs):
100104
if self._ad_problem._constant_jacobian:
101105
self._ad_solvers["update_adjoint"] = False
102106

107+
if not self._ad_solvers["hessian_lvs"]:
108+
with stop_annotating():
109+
self._ad_solvers["hessian_lvs"] = LinearVariationalSolver(
110+
self._ad_hessian_lvs_problem(block, problem._ad_adj_F),
111+
)
112+
103113
if not self._ad_solvers["tlm_lvs"]:
104114
with stop_annotating():
105115
self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
@@ -168,6 +178,27 @@ def _ad_adj_lvs_problem(self, block, adj_F):
168178
lvp._ad_count_map_update(_ad_count_map)
169179
return lvp
170180

181+
@no_annotations
182+
def _ad_hessian_lvs_problem(self, block, adj_dFdu):
183+
from firedrake import Function, Cofunction, LinearVariationalProblem
184+
185+
bcs = block._homogenize_bcs()
186+
adj_sol = Function(block.function_space)
187+
right_hand_side = Cofunction(block.function_space.dual())
188+
tmp_problem = LinearVariationalProblem(
189+
adj_dFdu, right_hand_side, adj_sol, bcs=bcs,
190+
constant_jacobian=self._ad_problem._constant_jacobian)
191+
192+
_ad_count_map, J_replace_map, _ = self._build_count_map(
193+
adj_dFdu, block._dependencies,
194+
)
195+
lvp = LinearVariationalProblem(
196+
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
197+
bcs=tmp_problem.bcs,
198+
constant_jacobian=self._ad_problem._constant_jacobian)
199+
lvp._ad_count_map_update(_ad_count_map)
200+
return lvp
201+
171202
@no_annotations
172203
def _ad_tlm_lvs_problem(self, block, F, u):
173204
from firedrake import Function, Cofunction, LinearVariationalProblem

0 commit comments

Comments
 (0)