Skip to content

Commit 61ff76a

Browse files
authored
Explicitly assemble the interpolate adjoint matrix (#4576)
* Explicitly assemble the interpolate adjoint matrix * Move renumbering logic to Interpolator * Enhacements for interpolation into VOM * SameMeshInterpolator: support matfree/explcit adjoint on Submesh
1 parent 210bae0 commit 61ff76a

File tree

7 files changed

+330
-234
lines changed

7 files changed

+330
-234
lines changed

firedrake/assemble.py

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
556556
result = expr.assemble(assembly_opts=opts)
557557
return tensor.assign(result) if tensor else result
558558
elif isinstance(expr, ufl.Interpolate):
559-
orig_expr = expr
560559
# Replace assembled children
561560
_, operand = expr.argument_slots()
562561
v, *assembled_operand = args
@@ -568,13 +567,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
568567
if (v, operand) != expr.argument_slots():
569568
expr = reconstruct_interp(operand, v=v)
570569

571-
# Different assembly procedures:
572-
# 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix)
573-
# 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action)
574-
# 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint
575-
# 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint
576-
# This can be generalized to the case where the first slot is an arbitray expression.
577570
rank = len(expr.arguments())
571+
if rank > 2:
572+
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
578573
# If argument numbers have been swapped => Adjoint.
579574
arg_operand = ufl.algorithms.extract_arguments(operand)
580575
is_adjoint = (arg_operand and arg_operand[0].number() == 0)
@@ -605,67 +600,14 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
605600
assemble(sub_interp, tensor=tensor.subfunctions[i])
606601
return tensor
607602

608-
# Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
609-
if not is_adjoint and rank == 2:
610-
v0, v1 = expr.arguments()
611-
expr = ufl.replace(expr, {v0: v0.reconstruct(number=v1.number()),
612-
v1: v1.reconstruct(number=v0.number())})
613-
v, operand = expr.argument_slots()
614-
615-
# Matrix-free adjoint interpolation is only implemented by SameMeshInterpolator
616-
# so we need assemble the interpolator matrix if the meshes are different
617-
target_mesh = V.mesh()
618-
source_mesh = extract_unique_domain(operand) or target_mesh
619-
if is_adjoint and rank < 2 and source_mesh is not target_mesh:
620-
expr = reconstruct_interp(operand, v=V)
621-
matfree = (rank == len(expr.arguments())) and (rank < 2)
622-
623603
# Get the interpolator
624604
interp_data = expr.interp_data.copy()
625605
default_missing_val = interp_data.pop('default_missing_val', None)
626-
if matfree and ((is_adjoint and rank == 1) or rank == 0):
627-
# Adjoint interpolation of a Cofunction or the action of a
628-
# Cofunction on an interpolated Function require INC access
629-
# on the output tensor
630-
interp_data["access"] = op2.INC
631-
632-
if rank == 1 and matfree and isinstance(tensor, firedrake.Function):
606+
if rank == 1 and isinstance(tensor, firedrake.Function):
633607
V = tensor
634608
interpolator = firedrake.Interpolator(expr, V, **interp_data)
635-
636609
# Assembly
637-
if matfree:
638-
# Assembling the operator
639-
return interpolator._interpolate(output=tensor, default_missing_val=default_missing_val)
640-
elif rank == 0:
641-
# Assembling the double action.
642-
Iu = interpolator._interpolate(default_missing_val=default_missing_val)
643-
return assemble(ufl.Action(v, Iu), tensor=tensor)
644-
elif rank == 1:
645-
# Assembling the action of the Jacobian adjoint.
646-
if is_adjoint:
647-
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
648-
# Assembling the Jacobian action.
649-
else:
650-
return interpolator._interpolate(operand, output=tensor, default_missing_val=default_missing_val)
651-
elif rank == 2:
652-
res = tensor.petscmat if tensor else PETSc.Mat()
653-
# Get the interpolation matrix
654-
op2_mat = interpolator.callable()
655-
petsc_mat = op2_mat.handle
656-
if is_adjoint:
657-
# Out-of-place Hermitian transpose
658-
petsc_mat.hermitianTranspose(out=res)
659-
elif res:
660-
# Copy the interpolation matrix into the output tensor
661-
petsc_mat.copy(result=res)
662-
else:
663-
res = petsc_mat
664-
if tensor is None:
665-
tensor = self.assembled_matrix(orig_expr, res)
666-
return tensor
667-
else:
668-
raise ValueError("Incompatible number of arguments.")
610+
return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val)
669611
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
670612
return tensor.assign(expr)
671613
elif tensor and isinstance(expr, ufl.ZeroBaseForm):

0 commit comments

Comments
 (0)