@@ -1237,33 +1237,28 @@ def _weight(self):
1237
1237
@cached_property
1238
1238
def _kernels (self ):
1239
1239
try :
1240
- # We generate custom prolongation and restriction kernels mainly because:
1241
- # 1. Code generation for the transpose of prolongation is not readily available
1242
- # 2. Dual evaluation of EnrichedElement is not yet implemented in FInAT
1240
+ prolong = partial (firedrake .assemble , firedrake .interpolate (self .uc , self .Vf ), tensor = self .uf )
1241
+ prolong ()
1242
+ self .rf = firedrake .Function (self .Vf .dual (), val = self .uf .dat )
1243
+ self .rc = firedrake .Function (self .Vc .dual (), val = self .uc .dat )
1244
+ restrict = partial (firedrake .assemble , firedrake .interpolate (firedrake .TestFunction (self .Vc ), self .rf ), tensor = self .rc )
1245
+ except NotImplementedError :
1246
+ # We generate custom prolongation and restriction kernels because
1247
+ # dual evaluation of EnrichedElement is not yet implemented in FInAT
1243
1248
uf_map = get_permuted_map (self .Vf )
1244
1249
uc_map = get_permuted_map (self .Vc )
1245
1250
prolong_kernel , restrict_kernel , coefficients = self .make_blas_kernels (self .Vf , self .Vc )
1246
1251
prolong_args = [prolong_kernel , self .uf .cell_set ,
1247
1252
self .uf .dat (op2 .INC , uf_map ),
1248
1253
self .uc .dat (op2 .READ , uc_map ),
1249
1254
self ._weight .dat (op2 .READ , uf_map )]
1250
- except ValueError :
1251
- # The elements do not have the expected tensor product structure
1252
- # Fall back to aij kernels
1253
- uf_map = self .Vf .cell_node_map ()
1254
- uc_map = self .Vc .cell_node_map ()
1255
- prolong_kernel , restrict_kernel , coefficients = self .make_kernels (self .Vf , self .Vc )
1256
- prolong_args = [prolong_kernel , self .uf .cell_set ,
1257
- self .uf .dat (op2 .WRITE , uf_map ),
1258
- self .uc .dat (op2 .READ , uc_map )]
1259
-
1260
- restrict_args = [restrict_kernel , self .uf .cell_set ,
1261
- self .uc .dat (op2 .INC , uc_map ),
1262
- self .uf .dat (op2 .READ , uf_map ),
1263
- self ._weight .dat (op2 .READ , uf_map )]
1264
- coefficient_args = [c .dat (op2 .READ , c .cell_node_map ()) for c in coefficients ]
1265
- prolong = op2 .ParLoop (* prolong_args , * coefficient_args )
1266
- restrict = op2 .ParLoop (* restrict_args , * coefficient_args )
1255
+ restrict_args = [restrict_kernel , self .uf .cell_set ,
1256
+ self .uc .dat (op2 .INC , uc_map ),
1257
+ self .uf .dat (op2 .READ , uf_map ),
1258
+ self ._weight .dat (op2 .READ , uf_map )]
1259
+ coefficient_args = [c .dat (op2 .READ , c .cell_node_map ()) for c in coefficients ]
1260
+ prolong = op2 .ParLoop (* prolong_args , * coefficient_args )
1261
+ restrict = op2 .ParLoop (* restrict_args , * coefficient_args )
1267
1262
return prolong , restrict
1268
1263
1269
1264
def _prolong (self ):
@@ -1571,56 +1566,9 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]):
1571
1566
P1 = P1 .function_space ()
1572
1567
if isinstance (Pk , firedrake .Function ):
1573
1568
Pk = Pk .function_space ()
1574
- sp = op2 .Sparsity ((Pk .dof_dset ,
1575
- P1 .dof_dset ),
1576
- {(i , j ): [(rmap , cmap , None )]
1577
- for i , rmap in enumerate (Pk .cell_node_map ())
1578
- for j , cmap in enumerate (P1 .cell_node_map ())
1579
- if i == j })
1580
- mat = op2 .Mat (sp , PETSc .ScalarType )
1581
- mesh = Pk .mesh ()
1582
-
1583
- fele = Pk .ufl_element ()
1584
- if type (fele ) is finat .ufl .MixedElement :
1585
- for i in range (fele .num_sub_elements ):
1586
- Pk_bcs_i = [bc for bc in Pk_bcs if bc .function_space ().index == i ]
1587
- P1_bcs_i = [bc for bc in P1_bcs if bc .function_space ().index == i ]
1588
-
1589
- rlgmap , clgmap = mat [i , i ].local_to_global_maps
1590
- rlgmap = Pk .sub (i ).local_to_global_map (Pk_bcs_i , lgmap = rlgmap )
1591
- clgmap = P1 .sub (i ).local_to_global_map (P1_bcs_i , lgmap = clgmap )
1592
- unroll = any (bc .function_space ().component is not None
1593
- for bc in chain (Pk_bcs_i , P1_bcs_i ) if bc is not None )
1594
- matarg = mat [i , i ](op2 .WRITE , (Pk .sub (i ).cell_node_map (), P1 .sub (i ).cell_node_map ()),
1595
- lgmaps = ((rlgmap , clgmap ), ), unroll_map = unroll )
1596
- expr = firedrake .TrialFunction (P1 .sub (i ))
1597
- kernel , coefficients = prolongation_transfer_kernel_action (Pk .sub (i ), expr )
1598
- parloop_args = [kernel , mesh .cell_set , matarg ]
1599
- for coefficient in coefficients :
1600
- m_ = coefficient .cell_node_map ()
1601
- parloop_args .append (coefficient .dat (op2 .READ , m_ ))
1602
-
1603
- op2 .par_loop (* parloop_args )
1604
-
1605
- else :
1606
- rlgmap , clgmap = mat .local_to_global_maps
1607
- rlgmap = Pk .local_to_global_map (Pk_bcs , lgmap = rlgmap )
1608
- clgmap = P1 .local_to_global_map (P1_bcs , lgmap = clgmap )
1609
- unroll = any (bc .function_space ().component is not None
1610
- for bc in chain (Pk_bcs , P1_bcs ) if bc is not None )
1611
- matarg = mat (op2 .WRITE , (Pk .cell_node_map (), P1 .cell_node_map ()),
1612
- lgmaps = ((rlgmap , clgmap ), ), unroll_map = unroll )
1613
- expr = firedrake .TrialFunction (P1 )
1614
- kernel , coefficients = prolongation_transfer_kernel_action (Pk , expr )
1615
- parloop_args = [kernel , mesh .cell_set , matarg ]
1616
- for coefficient in coefficients :
1617
- m_ = coefficient .cell_node_map ()
1618
- parloop_args .append (coefficient .dat (op2 .READ , m_ ))
1619
-
1620
- op2 .par_loop (* parloop_args )
1621
-
1622
- mat .assemble ()
1623
- return mat .handle
1569
+ bcs = P1_bcs + Pk_bcs
1570
+ mat = firedrake .assemble (firedrake .interpolate (firedrake .TrialFunction (P1 ), Pk ), bcs = bcs )
1571
+ return mat .petscmat
1624
1572
1625
1573
1626
1574
def prolongation_matrix_matfree (Vc , Vf , Vc_bcs = [], Vf_bcs = []):
0 commit comments