Skip to content

Commit 3bd935e

Browse files
committed
MixedInterpolator
1 parent 61ff76a commit 3bd935e

File tree

3 files changed

+173
-66
lines changed

3 files changed

+173
-66
lines changed

firedrake/assemble.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import finat.ufl
1919
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
2020
tsfc_interface, utils)
21-
from firedrake.formmanipulation import split_form
2221
from firedrake.adjoint_utils import annotate_assemble
2322
from firedrake.ufl_expr import extract_unique_domain
2423
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -570,36 +569,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
570569
rank = len(expr.arguments())
571570
if rank > 2:
572571
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
573-
# If argument numbers have been swapped => Adjoint.
574-
arg_operand = ufl.algorithms.extract_arguments(operand)
575-
is_adjoint = (arg_operand and arg_operand[0].number() == 0)
576-
577572
# Get the target space
578573
V = v.function_space().dual()
579574

580-
# Dual interpolation from mixed source
581-
if is_adjoint and len(V) > 1:
582-
cur = 0
583-
sub_operands = []
584-
components = numpy.reshape(operand, (-1,))
585-
for Vi in V:
586-
sub_operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
587-
cur += Vi.value_size
588-
589-
# Component-split of the primal operands interpolated into the dual argument-split
590-
split_interp = sum(reconstruct_interp(sub_operands[i], v=vi) for (i,), vi in split_form(v))
591-
return assemble(split_interp, tensor=tensor)
592-
593-
# Dual interpolation into mixed target
594-
if is_adjoint and len(arg_operand[0].function_space()) > 1 and rank == 1:
595-
V = arg_operand[0].function_space()
596-
tensor = tensor or firedrake.Cofunction(V.dual())
597-
598-
# Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
599-
for (i,), sub_interp in split_form(expr):
600-
assemble(sub_interp, tensor=tensor.subfunctions[i])
601-
return tensor
602-
603575
# Get the interpolator
604576
interp_data = expr.interp_data.copy()
605577
default_missing_val = interp_data.pop('default_missing_val', None)

firedrake/interpolation.py

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,18 @@ class Interpolator(abc.ABC):
263263

264264
def __new__(cls, expr, V, **kwargs):
265265
if isinstance(expr, ufl.Interpolate):
266+
# Mixed spaces are handled well only by the primal 1-form.
267+
# Are we a 2-form or a dual 1-form?
268+
arguments = expr.arguments()
269+
if any(not isinstance(a, Coargument) for a in arguments):
270+
# Do we have mixed source or target spaces?
271+
spaces = [a.function_space() for a in arguments]
272+
if len(spaces) < 2:
273+
spaces.append(V)
274+
if any(len(space) > 1 for space in spaces):
275+
return object.__new__(MixedInterpolator)
266276
expr, = expr.ufl_operands
277+
267278
target_mesh = as_domain(V)
268279
source_mesh = extract_unique_domain(expr) or target_mesh
269280
submesh_interp_implemented = \
@@ -309,9 +320,10 @@ def __init__(
309320
target_mesh = as_domain(V)
310321
source_mesh = extract_unique_domain(operand) or target_mesh
311322
vom_onto_other_vom = ((source_mesh is not target_mesh)
323+
and isinstance(self, SameMeshInterpolator)
312324
and isinstance(source_mesh.topology, VertexOnlyMeshTopology)
313325
and isinstance(target_mesh.topology, VertexOnlyMeshTopology))
314-
if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom:
326+
if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom:
315327
# For bespoke interpolation, we currently rely on different assembly procedures:
316328
# 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form)
317329
# 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form)
@@ -369,7 +381,7 @@ def _interpolate(self, *args, **kwargs):
369381
"""
370382
pass
371383

372-
def assemble(self, tensor=None, default_missing_val=None):
384+
def assemble(self, tensor=None, **kwargs):
373385
"""Assemble the operator (or its action)."""
374386
from firedrake.assemble import assemble
375387
needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate
@@ -383,13 +395,11 @@ def assemble(self, tensor=None, default_missing_val=None):
383395
if needs_adjoint:
384396
# Out-of-place Hermitian transpose
385397
petsc_mat.hermitianTranspose(out=res)
386-
elif res:
387-
petsc_mat.copy(res)
398+
elif tensor:
399+
petsc_mat.copy(tensor.petscmat)
388400
else:
389401
res = petsc_mat
390-
if tensor is None:
391-
tensor = firedrake.AssembledMatrix(arguments, self.bcs, res)
392-
return tensor
402+
return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res)
393403
else:
394404
# Assembling the action
395405
cofunctions = ()
@@ -401,11 +411,11 @@ def assemble(self, tensor=None, default_missing_val=None):
401411
cofunctions = (dual_arg,)
402412

403413
if needs_adjoint and len(arguments) == 0:
404-
Iu = self._interpolate(default_missing_val=default_missing_val)
414+
Iu = self._interpolate(**kwargs)
405415
return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor)
406416
else:
407417
return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint,
408-
default_missing_val=default_missing_val)
418+
**kwargs)
409419

410420

411421
class DofNotDefinedError(Exception):
@@ -975,33 +985,10 @@ def callable():
975985
return callable
976986
else:
977987
loops = []
978-
if len(V) == 1:
979-
expressions = (expr,)
980-
else:
981-
if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V)
982-
and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))):
983-
# Use subfunctions if they match the target shapes
984-
operands = operand.subfunctions
985-
else:
986-
# Unflatten the expression into the shapes of the mixed components
987-
offset = 0
988-
operands = []
989-
for Vsub in V:
990-
if len(Vsub.value_shape) == 0:
991-
operands.append(operand[offset])
992-
else:
993-
components = [operand[offset + j] for j in range(Vsub.value_size)]
994-
operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
995-
offset += Vsub.value_size
996-
997-
# Split the dual argument
998-
if isinstance(dual_arg, Cofunction):
999-
duals = dual_arg.subfunctions
1000-
elif isinstance(dual_arg, Coargument):
1001-
duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()]
1002-
else:
1003-
duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))]
1004-
expressions = map(expr._ufl_expr_reconstruct_, operands, duals)
988+
expressions = split_interpolate_target(expr)
989+
990+
if access == op2.INC:
991+
loops.append(tensor.zero)
1005992

1006993
# Interpolate each sub expression into each function space
1007994
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
@@ -1074,8 +1061,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10741061
parameters['scalar_type'] = utils.ScalarType
10751062

10761063
callables = ()
1077-
if access == op2.INC:
1078-
callables += (tensor.zero,)
10791064

10801065
# For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
10811066
# contributions from the facet DOFs of the dual argument.
@@ -1720,3 +1705,106 @@ def _wrap_dummy_mat(self):
17201705

17211706
def duplicate(self, mat=None, op=None):
17221707
return self._wrap_dummy_mat()
1708+
1709+
1710+
def split_interpolate_target(expr: ufl.Interpolate):
1711+
"""Split an Interpolate into the components (subfunctions) of the target space."""
1712+
dual_arg, operand = expr.argument_slots()
1713+
V = dual_arg.function_space().dual()
1714+
if len(V) == 1:
1715+
return (expr,)
1716+
# Split the target (dual) argument
1717+
if isinstance(dual_arg, Cofunction):
1718+
duals = dual_arg.subfunctions
1719+
elif isinstance(dual_arg, ufl.Coargument):
1720+
duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()]
1721+
else:
1722+
duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))]
1723+
# Split the operand into the target shapes
1724+
if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V)
1725+
and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))):
1726+
# Use subfunctions if they match the target shapes
1727+
operands = operand.subfunctions
1728+
else:
1729+
# Unflatten the expression into the target shapes
1730+
cur = 0
1731+
operands = []
1732+
components = numpy.reshape(operand, (-1,))
1733+
for Vi in V:
1734+
operands.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
1735+
cur += Vi.value_size
1736+
expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals))
1737+
return expressions
1738+
1739+
1740+
class MixedInterpolator(Interpolator):
1741+
"""A reusable interpolation object between MixedFunctionSpaces.
1742+
1743+
Parameters
1744+
----------
1745+
expr
1746+
The underlying ufl.Interpolate or the operand to the ufl.Interpolate.
1747+
V
1748+
The :class:`.FunctionSpace` or :class:`.Function` to
1749+
interpolate into.
1750+
bcs
1751+
A list of boundary conditions.
1752+
**kwargs
1753+
Any extra kwargs are passed on to the sub Interpolators.
1754+
For details see :class:`firedrake.interpolation.Interpolator`.
1755+
"""
1756+
def __init__(self, expr, V, bcs=None, **kwargs):
1757+
super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs)
1758+
expr = self.ufl_interpolate
1759+
bcs = bcs or ()
1760+
self.arguments = expr.arguments()
1761+
1762+
# Split the target (dual) argument
1763+
dual_split = split_interpolate_target(expr)
1764+
self.sub_interpolators = {}
1765+
for i, form in enumerate(dual_split):
1766+
# Split the source (primal) argument
1767+
for j, sub_interp in firedrake.formmanipulation.split_form(form):
1768+
j = max(j) if j else 0
1769+
# Ensure block sparsity
1770+
vi, operand = sub_interp.argument_slots()
1771+
if not isinstance(operand, ufl.classes.Zero):
1772+
Vtarget = vi.function_space().dual()
1773+
adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True
1774+
1775+
args = sub_interp.arguments()
1776+
Vsource = args[0 if adjoint else 1].function_space()
1777+
sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
1778+
1779+
indices = (j, i) if adjoint else (i, j)
1780+
Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs)
1781+
self.sub_interpolators[indices] = Isub
1782+
1783+
self.callable = self._callable
1784+
1785+
def _callable(self):
1786+
"""Assemble the operator."""
1787+
shape = tuple(len(a.function_space()) for a in self.arguments)
1788+
Isubs = self.sub_interpolators
1789+
blocks = numpy.reshape([Isubs[ij].callable().handle if ij in Isubs else PETSc.Mat()
1790+
for ij in numpy.ndindex(shape)], shape)
1791+
petscmat = PETSc.Mat().createNest(blocks)
1792+
tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat)
1793+
return tensor.M
1794+
1795+
def _interpolate(self, output=None, adjoint=False, **kwargs):
1796+
"""Assemble the action."""
1797+
tensor = output
1798+
rank = len(self.arguments)
1799+
if rank == 1:
1800+
# Assemble the action
1801+
if tensor is None:
1802+
V_dest = self.arguments[0].function_space().dual()
1803+
tensor = firedrake.Function(V_dest)
1804+
for k, fsub in enumerate(tensor.subfunctions):
1805+
fsub.assign(sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items() if i == k))
1806+
return tensor
1807+
elif rank == 0:
1808+
# Assemble the double action
1809+
result = sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items())
1810+
return tensor.assign(result) if tensor else result

tests/firedrake/regression/test_interpolate.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,3 +519,50 @@ def test_interpolate_logical_not():
519519
a = assemble(interpolate(conditional(Not(x < .2), 1, 0), V))
520520
b = assemble(interpolate(conditional(x >= .2, 1, 0), V))
521521
assert np.allclose(a.dat.data, b.dat.data)
522+
523+
524+
@pytest.mark.parametrize("mode", ("forward", "adjoint"))
525+
def test_mixed_matrix(mode):
526+
nx = 3
527+
mesh = UnitSquareMesh(nx, nx)
528+
529+
V1 = VectorFunctionSpace(mesh, "CG", 2)
530+
V2 = FunctionSpace(mesh, "CG", 1)
531+
V3 = FunctionSpace(mesh, "CG", 1)
532+
V4 = FunctionSpace(mesh, "DG", 1)
533+
534+
Z = V1 * V2
535+
W = V3 * V3 * V4
536+
537+
if mode == "forward":
538+
I = Interpolate(TrialFunction(Z), TestFunction(W.dual()))
539+
a = assemble(I)
540+
assert a.arguments()[0].function_space() == W.dual()
541+
assert a.arguments()[1].function_space() == Z
542+
assert a.petscmat.getSize() == (W.dim(), Z.dim())
543+
assert a.petscmat.getType() == "nest"
544+
545+
u = Function(Z)
546+
u.subfunctions[0].sub(0).assign(1)
547+
u.subfunctions[0].sub(1).assign(2)
548+
u.subfunctions[1].assign(3)
549+
result_matfree = assemble(Interpolate(u, TestFunction(W.dual())))
550+
elif mode == "adjoint":
551+
I = Interpolate(TestFunction(Z), TrialFunction(W.dual()))
552+
a = assemble(I)
553+
assert a.arguments()[1].function_space() == W.dual()
554+
assert a.arguments()[0].function_space() == Z
555+
assert a.petscmat.getSize() == (Z.dim(), W.dim())
556+
assert a.petscmat.getType() == "nest"
557+
558+
u = Function(W.dual())
559+
u.subfunctions[0].assign(1)
560+
u.subfunctions[1].assign(2)
561+
u.subfunctions[2].assign(3)
562+
result_matfree = assemble(Interpolate(TestFunction(Z), u))
563+
else:
564+
raise ValueError(f"Unrecognized mode {mode}")
565+
566+
result_explicit = assemble(action(a, u))
567+
for x, y in zip(result_explicit.subfunctions, result_matfree.subfunctions):
568+
assert np.allclose(x.dat.data, y.dat.data)

0 commit comments

Comments
 (0)