Skip to content

Commit 2297f08

Browse files
committed
PMGPC: use native matrix-free (adjoint) interpolation
1 parent f21dad9 commit 2297f08

File tree

1 file changed

+18
-70
lines changed
  • firedrake/preconditioners

1 file changed

+18
-70
lines changed

firedrake/preconditioners/pmg.py

Lines changed: 18 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,33 +1237,28 @@ def _weight(self):
12371237
@cached_property
12381238
def _kernels(self):
12391239
try:
1240-
# We generate custom prolongation and restriction kernels mainly because:
1241-
# 1. Code generation for the transpose of prolongation is not readily available
1242-
# 2. Dual evaluation of EnrichedElement is not yet implemented in FInAT
1240+
prolong = partial(firedrake.assemble, firedrake.interpolate(self.uc, self.Vf), tensor=self.uf)
1241+
prolong()
1242+
self.rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat)
1243+
self.rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat)
1244+
restrict = partial(firedrake.assemble, firedrake.interpolate(firedrake.TestFunction(self.Vc), self.rf), tensor=self.rc)
1245+
except NotImplementedError:
1246+
# We generate custom prolongation and restriction kernels because
1247+
# dual evaluation of EnrichedElement is not yet implemented in FInAT
12431248
uf_map = get_permuted_map(self.Vf)
12441249
uc_map = get_permuted_map(self.Vc)
12451250
prolong_kernel, restrict_kernel, coefficients = self.make_blas_kernels(self.Vf, self.Vc)
12461251
prolong_args = [prolong_kernel, self.uf.cell_set,
12471252
self.uf.dat(op2.INC, uf_map),
12481253
self.uc.dat(op2.READ, uc_map),
12491254
self._weight.dat(op2.READ, uf_map)]
1250-
except ValueError:
1251-
# The elements do not have the expected tensor product structure
1252-
# Fall back to aij kernels
1253-
uf_map = self.Vf.cell_node_map()
1254-
uc_map = self.Vc.cell_node_map()
1255-
prolong_kernel, restrict_kernel, coefficients = self.make_kernels(self.Vf, self.Vc)
1256-
prolong_args = [prolong_kernel, self.uf.cell_set,
1257-
self.uf.dat(op2.WRITE, uf_map),
1258-
self.uc.dat(op2.READ, uc_map)]
1259-
1260-
restrict_args = [restrict_kernel, self.uf.cell_set,
1261-
self.uc.dat(op2.INC, uc_map),
1262-
self.uf.dat(op2.READ, uf_map),
1263-
self._weight.dat(op2.READ, uf_map)]
1264-
coefficient_args = [c.dat(op2.READ, c.cell_node_map()) for c in coefficients]
1265-
prolong = op2.ParLoop(*prolong_args, *coefficient_args)
1266-
restrict = op2.ParLoop(*restrict_args, *coefficient_args)
1255+
restrict_args = [restrict_kernel, self.uf.cell_set,
1256+
self.uc.dat(op2.INC, uc_map),
1257+
self.uf.dat(op2.READ, uf_map),
1258+
self._weight.dat(op2.READ, uf_map)]
1259+
coefficient_args = [c.dat(op2.READ, c.cell_node_map()) for c in coefficients]
1260+
prolong = op2.ParLoop(*prolong_args, *coefficient_args)
1261+
restrict = op2.ParLoop(*restrict_args, *coefficient_args)
12671262
return prolong, restrict
12681263

12691264
def _prolong(self):
@@ -1571,56 +1566,9 @@ def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]):
15711566
P1 = P1.function_space()
15721567
if isinstance(Pk, firedrake.Function):
15731568
Pk = Pk.function_space()
1574-
sp = op2.Sparsity((Pk.dof_dset,
1575-
P1.dof_dset),
1576-
{(i, j): [(rmap, cmap, None)]
1577-
for i, rmap in enumerate(Pk.cell_node_map())
1578-
for j, cmap in enumerate(P1.cell_node_map())
1579-
if i == j})
1580-
mat = op2.Mat(sp, PETSc.ScalarType)
1581-
mesh = Pk.mesh()
1582-
1583-
fele = Pk.ufl_element()
1584-
if type(fele) is finat.ufl.MixedElement:
1585-
for i in range(fele.num_sub_elements):
1586-
Pk_bcs_i = [bc for bc in Pk_bcs if bc.function_space().index == i]
1587-
P1_bcs_i = [bc for bc in P1_bcs if bc.function_space().index == i]
1588-
1589-
rlgmap, clgmap = mat[i, i].local_to_global_maps
1590-
rlgmap = Pk.sub(i).local_to_global_map(Pk_bcs_i, lgmap=rlgmap)
1591-
clgmap = P1.sub(i).local_to_global_map(P1_bcs_i, lgmap=clgmap)
1592-
unroll = any(bc.function_space().component is not None
1593-
for bc in chain(Pk_bcs_i, P1_bcs_i) if bc is not None)
1594-
matarg = mat[i, i](op2.WRITE, (Pk.sub(i).cell_node_map(), P1.sub(i).cell_node_map()),
1595-
lgmaps=((rlgmap, clgmap), ), unroll_map=unroll)
1596-
expr = firedrake.TrialFunction(P1.sub(i))
1597-
kernel, coefficients = prolongation_transfer_kernel_action(Pk.sub(i), expr)
1598-
parloop_args = [kernel, mesh.cell_set, matarg]
1599-
for coefficient in coefficients:
1600-
m_ = coefficient.cell_node_map()
1601-
parloop_args.append(coefficient.dat(op2.READ, m_))
1602-
1603-
op2.par_loop(*parloop_args)
1604-
1605-
else:
1606-
rlgmap, clgmap = mat.local_to_global_maps
1607-
rlgmap = Pk.local_to_global_map(Pk_bcs, lgmap=rlgmap)
1608-
clgmap = P1.local_to_global_map(P1_bcs, lgmap=clgmap)
1609-
unroll = any(bc.function_space().component is not None
1610-
for bc in chain(Pk_bcs, P1_bcs) if bc is not None)
1611-
matarg = mat(op2.WRITE, (Pk.cell_node_map(), P1.cell_node_map()),
1612-
lgmaps=((rlgmap, clgmap), ), unroll_map=unroll)
1613-
expr = firedrake.TrialFunction(P1)
1614-
kernel, coefficients = prolongation_transfer_kernel_action(Pk, expr)
1615-
parloop_args = [kernel, mesh.cell_set, matarg]
1616-
for coefficient in coefficients:
1617-
m_ = coefficient.cell_node_map()
1618-
parloop_args.append(coefficient.dat(op2.READ, m_))
1619-
1620-
op2.par_loop(*parloop_args)
1621-
1622-
mat.assemble()
1623-
return mat.handle
1569+
bcs = P1_bcs + Pk_bcs
1570+
mat = firedrake.assemble(firedrake.interpolate(firedrake.TrialFunction(P1), Pk), bcs=bcs)
1571+
return mat.petscmat
16241572

16251573

16261574
def prolongation_matrix_matfree(Vc, Vf, Vc_bcs=[], Vf_bcs=[]):

0 commit comments

Comments
 (0)