Skip to content

Commit

Permalink
add bcs to Interpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 13, 2023
1 parent a81fa00 commit 1d333bd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 27 deletions.
24 changes: 17 additions & 7 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ class Interpolator(object):
:class:`Interpolator` is also collected).
"""
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE):
def __init__(self, expr, V, subset=None, freeze_expr=False, access=op2.WRITE, bcs=None):
try:
self.callable, arguments = make_interpolator(expr, V, subset, access)
self.callable, arguments = make_interpolator(expr, V, subset, access, bcs=bcs)
except FIAT.hdiv_trace.TraceError:
raise NotImplementedError("Can't interpolate onto traces sorry")
self.arguments = arguments
self.nargs = len(arguments)
self.freeze_expr = freeze_expr
self.expr = expr
self.V = V
self.bcs = bcs

@PETSc.Log.EventDecorator()
@annotate_interpolate
Expand Down Expand Up @@ -154,7 +155,7 @@ def interpolate(self, *function, output=None, transpose=False):


@PETSc.Log.EventDecorator()
def make_interpolator(expr, V, subset, access):
def make_interpolator(expr, V, subset, access, bcs=None):
assert isinstance(expr, ufl.classes.Expr)

arguments = extract_arguments(expr)
Expand Down Expand Up @@ -215,7 +216,10 @@ def make_interpolator(expr, V, subset, access):
if len(V) > 1:
raise NotImplementedError(
"UFL expressions for mixed functions are not yet supported.")
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access))
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))

if bcs and len(arguments) == 0:
loops.extend([partial(bc.apply, f) for bc in bcs])

def callable(loops, f):
for l in loops:
Expand All @@ -226,7 +230,7 @@ def callable(loops, f):


@utils.known_pyop2_safe
def _interpolator(V, tensor, expr, subset, arguments, access):
def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
try:
expr = ufl.as_ufl(expr)
except ufl.UFLException:
Expand Down Expand Up @@ -343,14 +347,20 @@ def _interpolator(V, tensor, expr, subset, arguments, access):
else:
assert access == op2.WRITE # Other access descriptors not done for Matrices.
rows_map = V.cell_node_map()
columns_map = arguments[0].function_space().cell_node_map()
Vcol = arguments[0].function_space()
columns_map = Vcol.cell_node_map()
if target_mesh is not source_mesh:
# Since the par_loop is over the target mesh cells we need to
# compose a map that takes us from target mesh cells to the
# function space nodes on the source mesh.
columns_map = compose_map_and_cache(target_mesh.cell_parent_cell_map,
columns_map)
parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map)))
lgmaps = None
if bcs:
bc_rows = [bc for bc in bcs if bc.function_space() == V]
bc_cols = [bc for bc in bcs if bc.function_space() == Vcol]
lgmaps = [(V.local_to_global_map(bc_rows), Vcol.local_to_global_map(bc_cols))]
parloop_args.append(tensor(op2.WRITE, (rows_map, columns_map), lgmaps=lgmaps))
if oriented:
co = target_mesh.cell_orientations()
parloop_args.append(co.dat(op2.READ, co.cell_node_map()))
Expand Down
18 changes: 4 additions & 14 deletions firedrake/preconditioners/hiptmair.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def coarsen(self, pc):
if P.getType() == "python":
ctx = P.getPythonContext()
a = ctx.a
bcs = ctx.bcs
bcs = tuple(ctx.bcs)
else:
ctx = dmhooks.get_appctx(pc.getDM())
problem = ctx._problem
a = problem.Jp or problem.J
bcs = problem.bcs
bcs = tuple(problem.bcs)

mesh = V.mesh()
element = V.ufl_element()
Expand All @@ -164,7 +164,7 @@ def coarsen(self, pc):

coarse_space = FunctionSpace(mesh, celement)
assert coarse_space.finat_element.formdegree + 1 == formdegree
coarse_space_bcs = [bc.reconstruct(V=coarse_space, g=0) for bc in bcs]
coarse_space_bcs = tuple([bc.reconstruct(V=coarse_space, g=0) for bc in bcs])

# Get only the zero-th order term of the form
beta = replace(expand_derivatives(a), {grad(t): zero(grad(t).ufl_shape) for t in a.arguments()})
Expand All @@ -179,18 +179,8 @@ def coarsen(self, pc):
coarse_operator += beta(test, shift*trial, coefficients={})

if G_callback is None:
from firedrake import Function
from firedrake.preconditioners.hypre_ams import chop

interp_petscmat = chop(Interpolator(dminus(test), V).callable().handle)
# FIXME bcs should be imposed during the assembly
cmask = Function(coarse_space)
with cmask.dat.vec as R:
R.set(1)
for bc in coarse_space_bcs:
bc.zero(cmask)
with cmask.dat.vec as R:
interp_petscmat.diagonalScale(R=R)
interp_petscmat = chop(Interpolator(dminus(test), V, bcs=bcs + coarse_space_bcs).callable().handle)
else:
interp_petscmat = G_callback(V, coarse_space, bcs, coarse_space_bcs)

Expand Down
7 changes: 1 addition & 6 deletions tests/multigrid/test_hiptmair.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from firedrake import *
import numpy
import pytest


Expand Down Expand Up @@ -36,8 +35,6 @@ def run_riesz_map(V, mat_type):
"snes_type": "ksponly",
"ksp_type": "cg",
"ksp_norm_type": "natural",
"ksp_monitor": None,
"ksp_view_eigenvalues": None,
"pc_type": "mg",
"mg_coarse": coarse,
"mg_levels": {
Expand All @@ -55,10 +52,8 @@ def run_riesz_map(V, mat_type):
assert sobolev in [HCurl, HDiv]
d = div if sobolev == HDiv else curl

x = SpatialCoordinate(V.mesh())
u_exact = Constant((1,2,4))
u_exact = Constant((1, 2, 4))
f = u_exact
# f = -grad(div(u_exact)) + u_exact if sobolev == HDiv else curl(curl(u_exact)) + u_exact

u = Function(V)
v = TestFunction(V)
Expand Down

0 comments on commit 1d333bd

Please sign in to comment.