Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 88 additions & 6 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Solver(Enum):
"""Enum for solver types."""
FORWARD = 0
ADJOINT = 1
TLM = 2
HESSIAN = 3


class GenericSolveBlock(Block):
Expand Down Expand Up @@ -220,6 +222,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):

return adj_sol, adj_sol_bdy

def _hessian_solve(self, *args):
return self._assemble_and_solve_adj_eq(*args)

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
return adj_sol_bdy.riesz_representation("l2")
Expand Down Expand Up @@ -378,8 +383,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input,
d2Fdu2)
dFdu_form = firedrake.adjoint(dFdu_form)
adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b,
compute_bdy)
adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy)
if self.adj2_cb is not None:
self.adj2_cb(adj_sol2)
if self.adj2_bdy_cb is not None and compute_bdy:
Expand Down Expand Up @@ -678,11 +682,30 @@ def _adjoint_solve(self, dJdu, compute_bdy):
u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
return u_sol, adj_sol_bdy

def _hessian_solve(self, adj_form, rhs, compute_bdy):
# self._ad_solver_replace_forms(Solver.HESSIAN)
# self._ad_solvers["hessian_lvs"].invalidate_jacobian()
self._ad_solvers["hessian_lvs"]._problem.F._components[1].assign(rhs)
self._ad_solvers["hessian_lvs"].solve()
u_sol = self._ad_solvers["hessian_lvs"]._problem.u

adj_sol_bdy = None
if compute_bdy:
jac_adj = self._ad_solvers["hessian_lvs"]._problem.J
adj_sol_bdy = self._compute_adj_bdy(
u_sol, adj_sol_bdy, jac_adj, rhs.copy()
)

return u_sol, adj_sol_bdy

def _ad_assign_map(self, form, solver):
if solver == Solver.FORWARD:
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
else:
elif solver == Solver.ADJOINT:
count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map
elif solver == Solver.TLM:
count_map = self._ad_solvers["tlm_lvs"]._problem._ad_count_map

assign_map = {}
form_ad_count_map = dict((count_map[coeff], coeff)
for coeff in form.coefficients())
Expand All @@ -693,8 +716,10 @@ def _ad_assign_map(self, form, solver):
firedrake.Cofunction)):
coeff_count = coeff.count()
if coeff_count in form_ad_count_map:
assign_map[form_ad_count_map[coeff_count]] = \
block_variable.saved_output
if solver == Solver.HESSIAN:
assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value
else:
assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output

if (
solver == Solver.ADJOINT
Expand All @@ -705,6 +730,7 @@ def _ad_assign_map(self, form, solver):
if coeff_count in form_ad_count_map:
assign_map[form_ad_count_map[coeff_count]] = \
block_variable.saved_output

return assign_map

def _ad_assign_coefficients(self, form, solver):
Expand All @@ -717,9 +743,17 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
problem = self._ad_solvers["forward_nlvs"]._problem
self._ad_assign_coefficients(problem.F, solver)
self._ad_assign_coefficients(problem.J, solver)
else:
elif solver == Solver.ADJOINT:
self._ad_assign_coefficients(
self._ad_solvers["adjoint_lvs"]._problem.J, solver)
elif solver == Solver.TLM:
self._ad_assign_coefficients(
self._ad_solvers["tlm_lvs"]._problem.J, solver
)
elif solver == Solver.HESSIAN:
self._ad_assign_coefficients(
self._ad_solvers["hessian_lvs"]._problem.J, solver
)

def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
compute_bdy = self._should_compute_boundary_adjoint(
Expand Down Expand Up @@ -796,6 +830,54 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,

return dFdm

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
prepared=None):
F_form = prepared["form"]
dFdu = prepared["dFdu"]

bcs = []
dFdm = 0.
for block_variable in self.get_dependencies():
tlm_value = block_variable.tlm_value
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, firedrake.DirichletBC):
if tlm_value is None:
bcs.append(c.reconstruct(g=0))
else:
bcs.append(tlm_value)
continue
elif isinstance(c, firedrake.MeshGeometry):
X = firedrake.SpatialCoordinate(c)
c_rep = X

if tlm_value is None:
continue

if c == self.func and not self.linear:
continue

dFdm += firedrake.derivative(-F_form, c_rep, tlm_value)

if isinstance(dFdm, float):
v = dFdu.arguments()[0]
dFdm = firedrake.inner(
firedrake.Constant(numpy.zeros(v.ufl_shape)), v
) * firedrake.dx

dFdm = ufl.algorithms.expand_derivatives(dFdm)
dFdm = firedrake.assemble(dFdm)

# XXX I dunno how this works
self._ad_solver_replace_forms(Solver.TLM)
self._ad_solvers["tlm_lvs"].invalidate_jacobian()
# update RHS
self._ad_solvers["tlm_lvs"]._problem.F._components[1].assign(dFdm)

self._ad_solvers["tlm_lvs"].solve()
return self._ad_solvers["tlm_lvs"]._problem.u


class ProjectBlock(SolveVarFormBlock):
def __init__(self, v, V, output, bcs=[], *args, **kwargs):
Expand Down
62 changes: 60 additions & 2 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def wrapper(self, *args, **kwargs):
self._ad_u = self.u_restrict
self._ad_bcs = self.bcs
self._ad_J = self.J

try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
Expand All @@ -27,8 +28,10 @@ def wrapper(self, *args, **kwargs):
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)

except (TypeError, NotImplementedError):
self._ad_adj_F = None

self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
self._ad_count_map = {}
return wrapper
Expand All @@ -49,7 +52,8 @@ def wrapper(self, problem, *args, **kwargs):
self._ad_args = args
self._ad_kwargs = kwargs
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
"recompute_count": 0}
"recompute_count": 0, "tlm_lvs": None,
"hessian_lvs": None}
self._ad_adj_cache = {}

return wrapper
Expand Down Expand Up @@ -100,6 +104,20 @@ def wrapper(self, **kwargs):
if self._ad_problem._constant_jacobian:
self._ad_solvers["update_adjoint"] = False

if not self._ad_solvers["hessian_lvs"]:
with stop_annotating():
self._ad_solvers["hessian_lvs"] = LinearVariationalSolver(
self._ad_hessian_lvs_problem(block, problem._ad_adj_F),
)

if not self._ad_solvers["tlm_lvs"]:
with stop_annotating():
self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
self._ad_tlm_lvs_problem(block, problem.F, problem.u_restrict)
)
if self._ad_problem._constant_jacobian:
self._ad_solvers["update_tlm"] = False

block._ad_solvers = self._ad_solvers

tape.add_block(block)
Expand Down Expand Up @@ -151,14 +169,54 @@ def _ad_adj_lvs_problem(self, block, adj_F):
# linear variational problem is created with a deep copy of the
# `block.adj_F` coefficients.
_ad_count_map, J_replace_map, _ = self._build_count_map(
adj_F, block._dependencies)
adj_F, block._dependencies,
)
lvp = LinearVariationalProblem(
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
bcs=tmp_problem.bcs,
constant_jacobian=self._ad_problem._constant_jacobian)
lvp._ad_count_map_update(_ad_count_map)
return lvp

@no_annotations
def _ad_hessian_lvs_problem(self, block, adj_dFdu):
from firedrake import Function, Cofunction, LinearVariationalProblem

bcs = block._homogenize_bcs()
adj_sol = Function(block.function_space)
right_hand_side = Cofunction(block.function_space.dual())
tmp_problem = LinearVariationalProblem(
adj_dFdu, right_hand_side, adj_sol, bcs=bcs,
constant_jacobian=self._ad_problem._constant_jacobian)

_ad_count_map, J_replace_map, _ = self._build_count_map(
adj_dFdu, block._dependencies,
)
lvp = LinearVariationalProblem(
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
bcs=tmp_problem.bcs,
constant_jacobian=self._ad_problem._constant_jacobian)
lvp._ad_count_map_update(_ad_count_map)
return lvp

@no_annotations
def _ad_tlm_lvs_problem(self, block, F, u):
from firedrake import Function, Cofunction, LinearVariationalProblem

lhs = derivative(F, u)
_ad_count_map, F_replace_map, _ = self._build_count_map(lhs, block._dependencies)
sol = Function(block.function_space)
rhs = Cofunction(block.function_space.dual())
lvp = LinearVariationalProblem(
replace(lhs, F_replace_map),
rhs,
sol,
bcs=block._homogenize_bcs(),
constant_jacobian=self._ad_problem._constant_jacobian,
)
lvp._ad_count_map_update(_ad_count_map)
return lvp

def _build_count_map(self, J, dependencies, F=None):
from firedrake import Function

Expand Down