Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hiptmair multigrid PC #2707

Merged
merged 26 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ class Interpolator(object):
:class:`Interpolator` is also collected).

"""
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE):
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None):
try:
self.callable, arguments = make_interpolator(expr, V, subset, access)
self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs)
except FIAT.hdiv_trace.TraceError:
raise NotImplementedError("Can't interpolate onto traces sorry")
self.arguments = arguments
self.nargs = len(arguments)
self.freeze_expr = freeze_expr
self.expr = expr
self.V = V
self.bcs = bcs

@PETSc.Log.EventDecorator()
@annotate_interpolate
Expand Down Expand Up @@ -154,7 +155,7 @@ def interpolate(self, *function, output=None, transpose=False):


@PETSc.Log.EventDecorator()
def make_interpolator(expr, V, subset, access):
def make_interpolator(expr, V, subset, access, bcs=None):
assert isinstance(expr, ufl.classes.Expr)

arguments = extract_arguments(expr)
Expand Down Expand Up @@ -215,7 +216,10 @@ def make_interpolator(expr, V, subset, access):
if len(V) > 1:
raise NotImplementedError(
"UFL expressions for mixed functions are not yet supported.")
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access))
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))

if bcs and len(arguments) == 0:
loops.extend([partial(bc.apply, f) for bc in bcs])

def callable(loops, f):
for l in loops:
Expand All @@ -226,7 +230,7 @@ def callable(loops, f):


@utils.known_pyop2_safe
def _interpolator(V, tensor, expr, subset, arguments, access):
def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
try:
expr = ufl.as_ufl(expr)
except ufl.UFLException:
Expand Down Expand Up @@ -343,14 +347,20 @@ def _interpolator(V, tensor, expr, subset, arguments, access):
else:
assert access == op2.WRITE # Other access descriptors not done for Matrices.
rows_map = V.cell_node_map()
columns_map = arguments[0].function_space().cell_node_map()
Vcol = arguments[0].function_space()
columns_map = Vcol.cell_node_map()
if target_mesh is not source_mesh:
# Since the par_loop is over the target mesh cells we need to
# compose a map that takes us from target mesh cells to the
# function space nodes on the source mesh.
columns_map = compose_map_and_cache(target_mesh.cell_parent_cell_map,
columns_map)
parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map)))
lgmaps = None
if bcs:
bc_rows = [bc for bc in bcs if bc.function_space() == V]
bc_cols = [bc for bc in bcs if bc.function_space() == Vcol]
lgmaps = [(V.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))]
parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map), lgmaps=lgmaps))
if oriented:
co = target_mesh.cell_orientations()
parloop_args.append(co.dat(op2.READ, co.cell_node_map()))
Expand Down
1 change: 1 addition & 0 deletions firedrake/preconditioners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from firedrake.preconditioners.hypre_ams import * # noqa: F401
from firedrake.preconditioners.hypre_ads import * # noqa: F401
from firedrake.preconditioners.fdm import * # noqa: F401
from firedrake.preconditioners.hiptmair import * # noqa: F401
262 changes: 262 additions & 0 deletions firedrake/preconditioners/hiptmair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import abc

from firedrake.petsc import PETSc
from firedrake.preconditioners.base import PCBase
import firedrake.dmhooks as dmhooks
import ufl


__all__ = ("HiptmairPC",)


class TwoLevelPC(PCBase):

needs_python_pmat = False

@abc.abstractmethod
def coarsen(self, pc):
"""Return a tuple with coarse bilinear form, coarse
boundary conditions, and coarse-to-fine interpolation matrix
"""
raise NotImplementedError

def initialize(self, pc):
from firedrake import parameters
from firedrake.assemble import allocate_matrix, TwoFormAssembler
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved

A, P = pc.getOperators()
appctx = self.get_appctx(pc)
fcp = appctx.get("form_compiler_parameters")

prefix = pc.getOptionsPrefix()
options_prefix = prefix + self._prefix
opts = PETSc.Options()

coarse_operator, coarse_space_bcs, interp_petscmat = self.coarsen(pc)

# Handle the coarse operator
coarse_options_prefix = options_prefix + "mg_coarse_"
coarse_mat_type = opts.getString(coarse_options_prefix + "mat_type",
parameters["default_matrix_type"])

self.coarse_op = allocate_matrix(coarse_operator,
bcs=coarse_space_bcs,
form_compiler_parameters=fcp,
mat_type=coarse_mat_type,
options_prefix=coarse_options_prefix)
self._assemble_coarse_op = TwoFormAssembler(coarse_operator, tensor=self.coarse_op,
form_compiler_parameters=fcp,
bcs=coarse_space_bcs).assemble
self._assemble_coarse_op()
coarse_opmat = self.coarse_op.petscmat

# We set up a PCMG object that uses the constructed interpolation
# matrix to generate the restriction/prolongation operators.
# This is a two-level multigrid preconditioner.
pcmg = PETSc.PC().create(comm=pc.comm)
pcmg.incrementTabLevel(1, parent=pc)

pcmg.setType(pc.Type.MG)
pcmg.setOptionsPrefix(options_prefix)
pcmg.setMGLevels(2)
pcmg.setMGType(pc.MGType.ADDITIVE)
pcmg.setMGCycleType(pc.MGCycleType.V)
pcmg.setMGInterpolation(1, interp_petscmat)
# FIXME the default for MGRScale is created with the wrong shape when dim(coarse) > dim(fine)
# FIXME there is no need for injection in a KSP context, probably this comes from the snes_ctx below
# as workaround define injection as the restriction of the solution times a zero vector
pcmg.setMGRScale(1, interp_petscmat.createVecRight())
pcmg.setOperators(A=A, P=P)

coarse_solver = pcmg.getMGCoarseSolve()
coarse_solver.setOperators(A=coarse_opmat, P=coarse_opmat)
# coarse space dm
coarse_space = coarse_operator.arguments()[-1].function_space()
coarse_dm = coarse_space.dm
coarse_solver.setDM(coarse_dm)
coarse_solver.setDMActive(False)
pcmg.setDM(pc.getDM())
pcmg.setFromOptions()
self.pc = pcmg
self._dm = coarse_dm

prefix = coarse_solver.getOptionsPrefix()
# Create new appctx
self._ctx_ref = self.new_snes_ctx(pc,
coarse_operator,
coarse_space_bcs,
coarse_mat_type,
fcp,
options_prefix=prefix)

with dmhooks.add_hooks(coarse_dm, self,
appctx=self._ctx_ref,
save=False):
coarse_solver.setFromOptions()

def update(self, pc):
self._assemble_coarse_op()
self.pc.setUp()

def apply(self, pc, X, Y):
dm = self._dm
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
self.pc.apply(X, Y)

def applyTranspose(self, pc, X, Y):
dm = self._dm
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
self.pc.applyTranspose(X, Y)

def view(self, pc, viewer=None):
super(TwoLevelPC, self).view(pc, viewer)
if hasattr(self, "pc"):
viewer.printfASCII("Two level PC\n")
self.pc.view(viewer)


class HiptmairPC(TwoLevelPC):
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved

_prefix = "hiptmair_"

def coarsen(self, pc):
from firedrake_citations import Citations
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
from firedrake import FunctionSpace, TestFunction, TrialFunction
from firedrake.interpolation import Interpolator
from ufl.algorithms.ad import expand_derivatives
from ufl import replace, zero, grad, curl, as_vector

Citations().register("Hiptmair1998")
appctx = self.get_appctx(pc)
V = dmhooks.get_function_space(pc.getDM())

_, P = pc.getOperators()
if P.getType() == "python":
ctx = P.getPythonContext()
a = ctx.a
bcs = tuple(ctx.bcs)
else:
ctx = dmhooks.get_appctx(pc.getDM())
problem = ctx._problem
a = problem.Jp or problem.J
bcs = tuple(problem.bcs)

mesh = V.mesh()
element = V.ufl_element()
degree = element.degree()
try:
degree = max(degree)
except TypeError:
pass
formdegree = V.finat_element.formdegree
if formdegree == 1:
celement = curl_to_grad(element)
dminus = grad
G_callback = appctx.get("get_gradient", None)
elif formdegree == 2:
celement = div_to_curl(element)
dminus = curl
if V.shape:
dminus = lambda u: as_vector([curl(u[k, ...]) for k in range(u.ufl_shape[0])])
G_callback = appctx.get("get_curl", None)
else:
raise ValueError("Hiptmair decomposition not available for", element)

coarse_space = FunctionSpace(mesh, celement)
assert coarse_space.finat_element.formdegree + 1 == formdegree
coarse_space_bcs = tuple(bc.reconstruct(V=coarse_space, g=0) for bc in bcs)

# Get only the zero-th order term of the form
beta = replace(expand_derivatives(a), {grad(t): zero(grad(t).ufl_shape) for t in a.arguments()})

test = TestFunction(coarse_space)
trial = TrialFunction(coarse_space)
coarse_operator = beta(dminus(test), dminus(trial), coefficients={})

if formdegree > 1 and degree > 1:
pefarrell marked this conversation as resolved.
Show resolved Hide resolved
shift = appctx.get("hiptmair_shift", None)
pefarrell marked this conversation as resolved.
Show resolved Hide resolved
if shift is not None:
coarse_operator += beta(test, shift*trial, coefficients={})

if G_callback is None:
from firedrake.preconditioners.hypre_ams import chop
interp_petscmat = chop(Interpolator(dminus(test), V, bcs=bcs + coarse_space_bcs).callable().handle)
else:
interp_petscmat = G_callback(V, coarse_space, bcs, coarse_space_bcs)

return coarse_operator, coarse_space_bcs, interp_petscmat


def curl_to_grad(ele):
if isinstance(ele, ufl.VectorElement):
return type(ele)(curl_to_grad(ele._sub_element), dim=ele.num_sub_elements())
elif isinstance(ele, ufl.TensorElement):
return type(ele)(curl_to_grad(ele._sub_element), shape=ele.value_shape(), symmetry=ele.symmetry())
elif isinstance(ele, ufl.MixedElement):
return type(ele)(*(curl_to_grad(e) for e in ele.sub_elements()))
elif isinstance(ele, ufl.RestrictedElement):
return ufl.RestrictedElement(curl_to_grad(ele._element), ele.restriction_domain())
else:
cell = ele.cell()
family = ele.family()
variant = ele.variant()
degree = ele.degree()
if family.startswith("Sminus"):
family = "S"
else:
family = "Lagrange"
if isinstance(degree, tuple) and isinstance(cell, ufl.TensorProductCell):
cells = ele.cell().sub_cells()
elems = [ufl.FiniteElement(family, cell=c, degree=d, variant=variant) for c, d in zip(cells, degree)]
return ufl.TensorProductElement(*elems, cell=cell)

return ufl.FiniteElement(family, cell=cell, degree=degree, variant=variant)


def div_to_curl(ele):
if isinstance(ele, ufl.VectorElement):
return type(ele)(div_to_curl(ele._sub_element), dim=ele.num_sub_elements())
elif isinstance(ele, ufl.TensorElement):
return type(ele)(div_to_curl(ele._sub_element), shape=ele.value_shape(), symmetry=ele.symmetry())
elif isinstance(ele, ufl.MixedElement):
return type(ele)(*(div_to_curl(e) for e in ele.sub_elements()))
elif isinstance(ele, ufl.RestrictedElement):
return ufl.RestrictedElement(div_to_curl(ele._element), ele.restriction_domain())
elif isinstance(ele, ufl.EnrichedElement):
return type(ele)(*(div_to_curl(e) for e in reversed(ele._elements)))
elif isinstance(ele, ufl.TensorProductElement):
return type(ele)(*(div_to_curl(e) for e in ele.sub_elements()), cell=ele.cell())
elif isinstance(ele, ufl.WithMapping):
return type(ele)(div_to_curl(ele.wrapee), ele.mapping())
elif isinstance(ele, ufl.BrokenElement):
return type(ele)(div_to_curl(ele._element))
elif isinstance(ele, ufl.HDivElement):
return ufl.HCurlElement(div_to_curl(ele._element))
elif isinstance(ele, ufl.HCurlElement):
raise ValueError("Expecting an H(div) element")
else:
degree = ele.degree()
family = ele.family()

if family in ["Lagrange", "CG", "Q"]:
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
family = "DG" if ele.cell().is_simplex() else "DQ"
degree = degree-1
elif family in ["Discontinuous Lagrange", "DG", "DQ"]:
family = "Lagrange"
degree = degree+1
elif family in ["Raviart-Thomas", "RT"]:
family = "N1curl"
elif family in ["Brezzi-Douglas-Marini", "BDM"]:
family = "N2curl"
elif family == "RTCF":
family = "RTCE"
elif family == "NCF":
family = "NCE"
elif family == "SminusF":
family = "SminusE"
elif family == "SminusDiv":
family = "SminusCurl"
else:
raise ValueError("Unexpected family %s" % family)

return ele.reconstruct(degree=degree, family=family)
14 changes: 14 additions & 0 deletions firedrake_citations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,17 @@ def print_at_exit(cls):
url = {https://www.jstor.org/stable/43693530}
}
""")

Citations().add("Hiptmair1998", """
@Misc{Hiptmair1998,
author = {Hiptmair, Ralf},
title = {{Multigrid Method for Maxwell's Equations}},
journal = {SIAM Journal on Numerical Analysis},
volume = {36},
number = {1},
pages = {204-225},
year = {1998},
doi = {10.1137/S0036142997326203},
url = {https://doi.org/10.1137/S0036142997326203},
}
""")
Loading