Skip to content

Commit

Permalink
Dualspace update (#2294)
Browse files Browse the repository at this point in the history
* Assembling Formsum + tests
* Add assembled matrix
* Add firedrake.Coargument
* Expunge UFLType from Firedrake

---------

Co-authored-by: India Marsden <imm1117@ic.ac.uk>
Co-authored-by: Colin Cotter <colin.cotter@imperial.ac.uk>
Co-authored-by: Sophia Vorderwuelbecke <sv2518@ic.ac.uk>
  • Loading branch information
4 people authored Sep 21, 2023
1 parent 6e624b7 commit 3676bab
Show file tree
Hide file tree
Showing 42 changed files with 1,502 additions and 279 deletions.
2 changes: 1 addition & 1 deletion demos/linear-wave-equation/linear_wave_equation.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ options at this point, we may either `lump` the mass, which reduces
the inversion to a pointwise division::

if lump_mass:
p += interpolate(assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx) / assemble(v*dx), V)
p.dat.data[:] += assemble(dt * inner(nabla_grad(v), nabla_grad(phi))*dx).dat.data_ro / assemble(v*dx).dat.data_ro

In the mass lumped case, we must now ensure that the resulting
solution for :math:`p` satisfies the boundary conditions::
Expand Down
2 changes: 2 additions & 0 deletions docs/notebooks/11-extract-adjoint-solutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,8 @@
" t = i*timesteps_per_export*dt\n",
" tricontourf(forward_solutions[i], axes=axs[i, 0])\n",
" adjoint_solution = dJdu if i == num_exports else solve_blocks[timesteps_per_export*i].adj_sol\n",
" # Get the Riesz representer\n",
" adjoint_solution = dJdu.riesz_representation(riesz_map=\"H1\")\n",
" tricontourf(adjoint_solution, axes=axs[i, 1])\n",
" axs[i, 0].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n",
" axs[i, 1].annotate('t={:.2f}'.format(t), (0.05, 0.05), color='white');\n",
Expand Down
1 change: 1 addition & 0 deletions firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from firedrake.assemble import *
from firedrake.bcs import *
from firedrake.checkpointing import *
from firedrake.cofunction import *
from firedrake.constant import *
from firedrake.exceptions import *
from firedrake.function import *
Expand Down
8 changes: 5 additions & 3 deletions firedrake/adjoint_utils/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def wrapper(*args, **kwargs):
output = assemble(*args, **kwargs)

from firedrake.function import Function
from firedrake.cofunction import Cofunction
form = args[0]
if isinstance(output, (numbers.Complex, Function)):
if isinstance(output, (numbers.Complex, Function, Cofunction)):
# Assembling a 0-form or 1-form (e.g. Form)
if not annotate:
return output

if not isinstance(output, (float, Function)):
if not isinstance(output, (float, Function, Cofunction)):
raise NotImplementedError("Taping for complex-valued 0-forms not yet done!")
output = create_overloaded_object(output)
block = AssembleBlock(form, ad_block_tag=ad_block_tag)
Expand All @@ -34,7 +36,7 @@ def wrapper(*args, **kwargs):

block.add_output(output.block_variable)
else:
# Assembled a matrix
# Assembled a 2-form
output.form = form

return output
Expand Down
57 changes: 31 additions & 26 deletions firedrake/adjoint_utils/blocks/assembly.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ufl
import firedrake
from ufl.formatting.ufl2unicode import ufl2unicode
from pyadjoint import Block, create_overloaded_object
from pyadjoint import Block, AdjFloat, create_overloaded_object
from .backend import Backend
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint

Expand All @@ -13,8 +13,11 @@ def __init__(self, form, ad_block_tag=None):
if self.backend.__name__ != "firedrake":
mesh = self.form.ufl_domain().ufl_cargo()
else:
mesh = self.form.ufl_domain()
self.add_dependency(mesh)
mesh = self.form.ufl_domain() if hasattr(self.form, 'ufl_domain') else None

if mesh:
self.add_dependency(mesh)

for c in self.form.coefficients():
self.add_dependency(c, no_duplicates=True)

Expand All @@ -28,7 +31,7 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None,
`<(dform/dc_rep)*, adj_input>`
- If `form` has arity 0 => `dform/dc_rep` is a 1-form and
`adj_input` a foat, we can simply use the `*` operator.
`adj_input` a float, we can simply use the `*` operator.
- If `form` has arity 1 => `dform/dc_rep` is a 2-form and we can
symbolically take its adjoint and then apply the action on
Expand All @@ -38,31 +41,38 @@ def compute_action_adjoint(self, adj_input, arity_form, form=None,
if dform is None:
dc = self.backend.TestFunction(space)
dform = self.backend.derivative(form, c_rep, dc)
dform_vector = self.compat.assemble_adjoint_value(dform)
# Return a Vector scaled by the scalar `adj_input`
return dform_vector * adj_input, dform
dform_adj = self.compat.assemble_adjoint_value(dform)
if dform_adj == 0:
# `dform_adj` is a `ZeroBaseForm`
return AdjFloat(0.), dform
# Return the adjoint model of `form` scaled by the scalar `adj_input`
adj_output = dform_adj._ad_mul(adj_input)
return adj_output, dform
elif arity_form == 1:
if dform is None:
dc = self.backend.TrialFunction(space)
dform = self.backend.derivative(form, c_rep, dc)
# Get the Function
adj_input = adj_input.function
# Symbolic operators such as action/adjoint require derivatives to
# have been expanded beforehand. However, UFL doesn't support
# expanding coordinate derivatives of Coefficients in physical
# space, implying that we can't symbolically take the
# action/adjoint of the Jacobian for SpatialCoordinates. ->
# Workaround: Apply action/adjoint numerically (using PETSc).
# action/adjoint of the Jacobian for SpatialCoordinates.
# -> Workaround: Apply action/adjoint numerically (using PETSc).
if not isinstance(c_rep, self.backend.SpatialCoordinate):
# Symbolically compute: (dform/dc_rep)^* * adj_input
adj_output = self.backend.action(self.backend.adjoint(dform),
adj_input)
adj_output = self.compat.assemble_adjoint_value(adj_output)
else:
adj_output = self.backend.Cofunction(space.dual())
# Assemble `dform`: derivatives are expanded along the way
# which may lead to a ZeroBaseForm
assembled_dform = self.compat.assemble_adjoint_value(dform)
if assembled_dform == 0:
return adj_output, dform
# Get PETSc matrix
dform_mat = self.compat.assemble_adjoint_value(dform).petscmat
dform_mat = assembled_dform.petscmat
# Action of the adjoint (Hermitian transpose)
adj_output = self.backend.Function(space)
with adj_input.dat.vec_ro as v_vec:
with adj_output.dat.vec as res_vec:
dform_mat.multHermitian(v_vec, res_vec)
Expand Down Expand Up @@ -105,7 +115,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
if self.compat.isconstant(c):
mesh = self.compat.extract_mesh_from_form(self.form)
space = c._ad_function_space(mesh)
elif isinstance(c, self.backend.Function):
elif isinstance(c, (self.backend.Function, self.backend.Cofunction)):
space = c.function_space()
elif isinstance(c, self.compat.MeshType):
c_rep = self.backend.SpatialCoordinate(c_rep)
Expand All @@ -123,8 +133,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
form = prepared
dform = 0.

from ufl.algorithms.analysis import extract_arguments
arity_form = len(extract_arguments(form))
for bv in self.get_dependencies():
c_rep = bv.saved_output
tlm_value = bv.tlm_value
Expand All @@ -133,15 +141,14 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
continue
if isinstance(c_rep, self.compat.MeshType):
X = self.backend.SpatialCoordinate(c_rep)
# Spatial coordinates derivatives cannot be expanded in the physical space,
# which is required by symbolic operators such as `action`.
dform += self.backend.derivative(form, X, tlm_value)
else:
dform += self.backend.derivative(form, c_rep, tlm_value)
dform += self.backend.action(self.backend.derivative(form, c_rep), tlm_value)
if not isinstance(dform, float):
dform = ufl.algorithms.expand_derivatives(dform)
dform = self.compat.assemble_adjoint_value(dform)
if arity_form == 1 and dform != 0:
# Then dform is a Vector
dform = dform.function
return dform

def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs,
Expand All @@ -165,7 +172,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
if self.compat.isconstant(c1):
mesh = self.compat.extract_mesh_from_form(form)
space = c1._ad_function_space(mesh)
elif isinstance(c1, self.backend.Function):
elif isinstance(c1, (self.backend.Function, self.backend.Cofunction)):
space = c1.function_space()
elif isinstance(c1, self.compat.ExpressionType):
mesh = form.ufl_domain().ufl_cargo()
Expand All @@ -180,7 +187,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
hessian_input, arity_form, form, c1_rep, space
)

ddform = 0
ddform = 0.
for other_idx, bv in relevant_dependencies:
c2_rep = bv.saved_output
tlm_input = bv.tlm_value
Expand All @@ -196,10 +203,8 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,

if not isinstance(ddform, float):
ddform = ufl.algorithms.expand_derivatives(ddform)
if not ddform.empty():
hessian_outputs += self.compute_action_adjoint(
adj_input, arity_form, dform=ddform
)[0]
if not (isinstance(ddform, ufl.ZeroBaseForm) or (isinstance(ddform, ufl.Form) and ddform.empty())):
hessian_outputs += self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0]

if isinstance(c1, self.compat.ExpressionType):
return [(hessian_outputs, space)]
Expand Down
12 changes: 2 additions & 10 deletions firedrake/adjoint_utils/blocks/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def extract_bc_subvector(value, Vtarget, bc):
for idx in bc._indices:
r = r.sub(idx)
assert Vtarget == r.function_space()
return r.vector()
return r
compat.extract_bc_subvector = extract_bc_subvector

def extract_mesh_from_form(form):
Expand Down Expand Up @@ -115,15 +115,7 @@ def constant_function_firedrake_compat(value):
return value.dat.data
compat.constant_function_firedrake_compat = constant_function_firedrake_compat

def assemble_adjoint_value(*args, **kwargs):
"""A wrapper around Firedrake's assemble that returns a Vector
instead of a Function when assembling a 1-form."""
result = backend.assemble(*args, **kwargs)
if isinstance(result, backend.Function):
return result.vector()
else:
return result
compat.assemble_adjoint_value = assemble_adjoint_value
compat.assemble_adjoint_value = backend.assemble

def gather(vec):
return vec.gather()
Expand Down
21 changes: 10 additions & 11 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,16 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if self.compat.isconstant(c):
adj_value = self.backend.Function(self.parent_space)
adj_input.apply(adj_value.vector())
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = self.compat.extract_bc_subvector(
adj_value, self.collapsed_space, bc
)
adj_value = self.compat.function_from_vector(
self.collapsed_space, vec
)
adj_value = self.backend.Function(self.collapsed_space, vec.dat)

if adj_value.ufl_shape == () or adj_value.ufl_shape[0] <= 1:
r = adj_value.vector().sum()
r = adj_value.dat.data_ro.sum()
else:
output = []
subindices = _extract_subindices(self.function_space)
Expand All @@ -64,7 +62,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
output.append(
current_subfunc.sub(
prev_idx, deepcopy=True
).vector().sum()
).dat.data_ro.sum()
)

r = self.backend.cpp.la.Vector(self.backend.MPI.comm_world,
Expand All @@ -77,14 +75,14 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = self.backend.Function(self.parent_space)
adj_input.apply(adj_value.vector())
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value)
r = self.compat.extract_bc_subvector(
adj_value, c.function_space(), bc
)
elif isinstance(c, self.compat.Expression):
adj_value = self.backend.Function(self.parent_space)
adj_input.apply(adj_value.vector())
adj_value = self.backend.Function(self.parent_space.dual())
adj_input.apply(adj_value)
output = self.compat.extract_bc_subvector(
adj_value, self.collapsed_space, bc
)
Expand All @@ -93,6 +91,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = r
else:
adj_output += r

return adj_output

def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
Expand Down
Loading

0 comments on commit 3676bab

Please sign in to comment.