@@ -263,7 +263,18 @@ class Interpolator(abc.ABC):
 
     def __new__(cls, expr, V, **kwargs):
         if isinstance(expr, ufl.Interpolate):
+            # Mixed spaces are handled well only by the primal 1-form.
+            # Are we a 2-form or a dual 1-form?
+            arguments = expr.arguments()
+            if any(not isinstance(a, Coargument) for a in arguments):
+                # Do we have mixed source or target spaces?
+                spaces = [a.function_space() for a in arguments]
+                if len(spaces) < 2:
+                    spaces.append(V)
+                if any(len(space) > 1 for space in spaces):
+                    return object.__new__(MixedInterpolator)
             expr, = expr.ufl_operands
+
         target_mesh = as_domain(V)
         source_mesh = extract_unique_domain(expr) or target_mesh
         submesh_interp_implemented = \
@@ -309,9 +320,10 @@ def __init__(
         target_mesh = as_domain(V)
         source_mesh = extract_unique_domain(operand) or target_mesh
         vom_onto_other_vom = ((source_mesh is not target_mesh)
+                              and isinstance(self, SameMeshInterpolator)
                               and isinstance(source_mesh.topology, VertexOnlyMeshTopology)
                               and isinstance(target_mesh.topology, VertexOnlyMeshTopology))
-        if not isinstance(self, SameMeshInterpolator) or vom_onto_other_vom:
+        if isinstance(self, CrossMeshInterpolator) or vom_onto_other_vom:
             # For bespoke interpolation, we currently rely on different assembly procedures:
             # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Forward operator (2-form)
             # 2) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Adjoint operator (2-form)
@@ -369,7 +381,7 @@ def _interpolate(self, *args, **kwargs):
         """
         pass
 
-    def assemble(self, tensor=None, default_missing_val=None):
+    def assemble(self, tensor=None, **kwargs):
         """Assemble the operator (or its action)."""
         from firedrake.assemble import assemble
         needs_adjoint = self.ufl_interpolate_renumbered != self.ufl_interpolate
@@ -383,13 +395,11 @@ def assemble(self, tensor=None, default_missing_val=None):
             if needs_adjoint:
                 # Out-of-place Hermitian transpose
                 petsc_mat.hermitianTranspose(out=res)
-            elif res:
-                petsc_mat.copy(res)
+            elif tensor:
+                petsc_mat.copy(tensor.petscmat)
             else:
                 res = petsc_mat
-            if tensor is None:
-                tensor = firedrake.AssembledMatrix(arguments, self.bcs, res)
-            return tensor
+            return tensor or firedrake.AssembledMatrix(arguments, self.bcs, res)
         else:
             # Assembling the action
             cofunctions = ()
@@ -401,11 +411,11 @@ def assemble(self, tensor=None, default_missing_val=None):
                 cofunctions = (dual_arg,)
 
             if needs_adjoint and len(arguments) == 0:
-                Iu = self._interpolate(default_missing_val=default_missing_val)
+                Iu = self._interpolate(**kwargs)
                 return assemble(ufl.Action(*cofunctions, Iu), tensor=tensor)
             else:
                 return self._interpolate(*cofunctions, output=tensor, adjoint=needs_adjoint,
-                                         default_missing_val=default_missing_val)
+                                         **kwargs)
 
 
 class DofNotDefinedError(Exception):
@@ -975,33 +985,10 @@ def callable():
         return callable
     else:
         loops = []
-        if len(V) == 1:
-            expressions = (expr,)
-        else:
-            if (hasattr(operand, "subfunctions") and len(operand.subfunctions) == len(V)
-                    and all(sub_op.ufl_shape == Vsub.value_shape for Vsub, sub_op in zip(V, operand.subfunctions))):
-                # Use subfunctions if they match the target shapes
-                operands = operand.subfunctions
-            else:
-                # Unflatten the expression into the shapes of the mixed components
-                offset = 0
-                operands = []
-                for Vsub in V:
-                    if len(Vsub.value_shape) == 0:
-                        operands.append(operand[offset])
-                    else:
-                        components = [operand[offset + j] for j in range(Vsub.value_size)]
-                        operands.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
-                    offset += Vsub.value_size
-
-            # Split the dual argument
-            if isinstance(dual_arg, Cofunction):
-                duals = dual_arg.subfunctions
-            elif isinstance(dual_arg, Coargument):
-                duals = [Coargument(Vsub, number=dual_arg.number()) for Vsub in dual_arg.function_space()]
-            else:
-                duals = [v for _, v in sorted(firedrake.formmanipulation.split_form(dual_arg))]
-            expressions = map(expr._ufl_expr_reconstruct_, operands, duals)
+        expressions = split_interpolate_target(expr)
+
+        if access == op2.INC:
+            loops.append(tensor.zero)
 
         # Interpolate each sub expression into each function space
         for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
@@ -1074,8 +1061,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
     parameters['scalar_type'] = utils.ScalarType
 
     callables = ()
-    if access == op2.INC:
-        callables += (tensor.zero,)
 
     # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
     # contributions from the facet DOFs of the dual argument.
@@ -1720,3 +1705,106 @@ def _wrap_dummy_mat(self):
 
     def duplicate(self, mat=None, op=None):
         return self._wrap_dummy_mat()
+
+
+def split_interpolate_target(expr: ufl.Interpolate):
+    """Split an Interpolate into the components (subfunctions) of the target space."""
+    dual_arg, operand = expr.argument_slots()
+    V = dual_arg.function_space().dual()
+    if len(V) == 1:
+        return (expr,)
+    # Split the target (dual) argument
+    if isinstance(dual_arg, Cofunction):
+        duals = dual_arg.subfunctions
+    elif isinstance(dual_arg, ufl.Coargument):
+        duals = [Coargument(Vsub, dual_arg.number()) for Vsub in dual_arg.function_space()]
+    else:
+        duals = [vi for _, vi in sorted(firedrake.formmanipulation.split_form(dual_arg))]
+    # Split the operand into the target shapes
+    if (isinstance(operand, firedrake.Function) and len(operand.subfunctions) == len(V)
+            and all(fsub.ufl_shape == Vsub.value_shape for Vsub, fsub in zip(V, operand.subfunctions))):
+        # Use subfunctions if they match the target shapes
+        operands = operand.subfunctions
+    else:
+        # Unflatten the expression into the target shapes
+        cur = 0
+        operands = []
+        components = numpy.reshape(operand, (-1,))
+        for Vi in V:
+            operands.append(ufl.as_tensor(components[cur:cur + Vi.value_size].reshape(Vi.value_shape)))
+            cur += Vi.value_size
+    expressions = tuple(map(expr._ufl_expr_reconstruct_, operands, duals))
+    return expressions
+
+
+class MixedInterpolator(Interpolator):
+    """A reusable interpolation object between MixedFunctionSpaces.
+
+    Parameters
+    ----------
+    expr
+        The underlying ufl.Interpolate or the operand to the ufl.Interpolate.
+    V
+        The :class:`.FunctionSpace` or :class:`.Function` to
+        interpolate into.
+    bcs
+        A list of boundary conditions.
+    **kwargs
+        Any extra kwargs are passed on to the sub Interpolators.
+        For details see :class:`firedrake.interpolation.Interpolator`.
+    """
+    def __init__(self, expr, V, bcs=None, **kwargs):
+        super(MixedInterpolator, self).__init__(expr, V, bcs=bcs, **kwargs)
+        expr = self.ufl_interpolate
+        bcs = bcs or ()
+        self.arguments = expr.arguments()
+
+        # Split the target (dual) argument
+        dual_split = split_interpolate_target(expr)
+        self.sub_interpolators = {}
+        for i, form in enumerate(dual_split):
+            # Split the source (primal) argument
+            for j, sub_interp in firedrake.formmanipulation.split_form(form):
+                j = max(j) if j else 0
+                # Ensure block sparsity
+                vi, operand = sub_interp.argument_slots()
+                if not isinstance(operand, ufl.classes.Zero):
+                    Vtarget = vi.function_space().dual()
+                    adjoint = vi.number() == 1 if isinstance(vi, Coargument) else True
+
+                    args = sub_interp.arguments()
+                    Vsource = args[0 if adjoint else 1].function_space()
+                    sub_bcs = [bc for bc in bcs if bc.function_space() in {Vsource, Vtarget}]
+
+                    indices = (j, i) if adjoint else (i, j)
+                    Isub = Interpolator(sub_interp, Vtarget, bcs=sub_bcs, **kwargs)
+                    self.sub_interpolators[indices] = Isub
+
+        self.callable = self._callable
+
+    def _callable(self):
+        """Assemble the operator."""
+        shape = tuple(len(a.function_space()) for a in self.arguments)
+        Isubs = self.sub_interpolators
+        blocks = numpy.reshape([Isubs[ij].callable().handle if ij in Isubs else PETSc.Mat()
+                                for ij in numpy.ndindex(shape)], shape)
+        petscmat = PETSc.Mat().createNest(blocks)
+        tensor = firedrake.AssembledMatrix(self.arguments, self.bcs, petscmat)
+        return tensor.M
+
+    def _interpolate(self, output=None, adjoint=False, **kwargs):
+        """Assemble the action."""
+        tensor = output
+        rank = len(self.arguments)
+        if rank == 1:
+            # Assemble the action
+            if tensor is None:
+                V_dest = self.arguments[0].function_space().dual()
+                tensor = firedrake.Function(V_dest)
+            for k, fsub in enumerate(tensor.subfunctions):
+                fsub.assign(sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items() if i == k))
+            return tensor
+        elif rank == 0:
+            # Assemble the double action
+            result = sum(Isub.assemble(**kwargs) for (i, j), Isub in self.sub_interpolators.items())
+            return tensor.assign(result) if tensor else result
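
Usage note (not part of the patch): a minimal sketch of the mixed-space interpolation that the new MixedInterpolator dispatch is intended to catch. It assumes the standard public Firedrake API, where the symbolic interpolate expression is assembled with assemble; the mesh, the particular spaces and all variable names below are illustrative only and are not taken from the patch or its tests.

from firedrake import *

# Hypothetical setup: any mesh and mixed space would do.
mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)          # scalar component
Q = VectorFunctionSpace(mesh, "DG", 0)    # 2-vector component
W = V * Q                                 # mixed target space, total value size 3

x, y = SpatialCoordinate(mesh)
expr = as_vector([x*y, x, y])             # flat expression matching W's value shape

# Primal 1-form: interpolate the expression into the mixed space component-wise.
w = assemble(interpolate(expr, W))

# 2-form: an interpolation operator with a mixed source/target space, which this
# patch routes to MixedInterpolator and assembles block by block into a nested
# PETSc matrix.
u = TrialFunction(W)
J = assemble(interpolate(u, W))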