Skip to content

Commit

Permalink
submesh: enable cell submesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Aug 7, 2024
1 parent 5ab86d9 commit 37f79a2
Show file tree
Hide file tree
Showing 8 changed files with 884 additions and 119 deletions.
8 changes: 4 additions & 4 deletions firedrake/__future__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ufl.domain import as_domain, extract_unique_domain
from ufl.algorithms import extract_arguments
from firedrake.mesh import VertexOnlyMeshTopology
from firedrake.mesh import MeshTopology, VertexOnlyMeshTopology
from firedrake.interpolation import (interpolate as interpolate_old,
Interpolator as InterpolatorOld,
SameMeshInterpolator as SameMeshInterpolatorOld,
Expand All @@ -16,13 +16,13 @@ class Interpolator(InterpolatorOld):
def __new__(cls, expr, V, **kwargs):
target_mesh = as_domain(V)
source_mesh = extract_unique_domain(expr) or target_mesh
if target_mesh is not source_mesh:
if target_mesh is source_mesh or all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]:
return object.__new__(SameMeshInterpolator)
else:
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
return object.__new__(SameMeshInterpolator)
else:
return object.__new__(CrossMeshInterpolator)
else:
return object.__new__(SameMeshInterpolator)

interpolate = InterpolatorOld._interpolate_future

Expand Down
33 changes: 19 additions & 14 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,20 +1338,22 @@ def _make_maps_and_regions(self):
test, trial = self._form.arguments()
if self._allocation_integral_types is not None:
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, self._allocation_integral_types)
elif any(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels):
elif any(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels):
# Handle special cases: slate or split=False
assert all(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels)
assert all(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels)
allocation_integral_types = set(local_kernel.kinfo.integral_type
for local_kernel in self._all_local_kernels)
for local_kernel, _ in self._all_local_kernels)
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, allocation_integral_types)
else:
maps_and_regions = defaultdict(lambda: defaultdict(set))
for local_kernel in self._all_local_kernels:
for local_kernel, subdomain_id in self._all_local_kernels:
i, j = local_kernel.indices
mesh = self._form.ufl_domains()[local_kernel.kinfo.domain_number]
# Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
integral_type = local_kernel.kinfo.integral_type
rmap_ = test.function_space().topological[i].entity_node_map(integral_type)
cmap_ = trial.function_space().topological[j].entity_node_map(integral_type)
all_subdomain_ids = self.all_integer_subdomain_ids
rmap_ = test.function_space().topological[i].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
cmap_ = trial.function_space().topological[j].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
return {block_indices: [map_pair + (tuple(region_set), ) for map_pair, region_set in map_pair_to_region_set.items()]
Expand All @@ -1367,8 +1369,11 @@ def _make_maps_and_regions_default(test, trial, allocation_integral_types):
# Use outer product of component maps.
for integral_type in allocation_integral_types:
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(integral_type)):
for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(integral_type)):
for i, Vrow in enumerate(test.function_space()):
for j, Vcol in enumerate(trial.function_space()):
mesh = Vrow.mesh()
rmap_ = Vrow.topological.entity_node_map(mesh.topology, integral_type, None, None)
cmap_ = Vcol.topological.entity_node_map(mesh.topology, integral_type, None, None)
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
return {block_indices: [map_pair + (tuple(region_set), ) for map_pair, region_set in map_pair_to_region_set.items()]
for block_indices, map_pair_to_region_set in maps_and_regions.items()}
Expand All @@ -1390,7 +1395,7 @@ def _all_local_kernels(self):
When constructing sparsity, we use all parloop_builders
that are to be used in the actual assembly.
"""
all_local_kernels = tuple(local_kernel for local_kernel, _ in self.local_kernels)
all_local_kernels = self.local_kernels
for bc in self._bcs:
if isinstance(bc, EquationBCSplit):
_assembler = type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False)
Expand Down Expand Up @@ -1555,7 +1560,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
self._form = form
self._indices, self._kinfo = local_knl
self._subdomain_id = subdomain_id
self._all_integer_subdomain_ids = all_integer_subdomain_ids.get(self._kinfo.integral_type, None)
self._all_integer_subdomain_ids = all_integer_subdomain_ids
self._diagonal = diagonal
self._unroll = unroll

Expand Down Expand Up @@ -1611,7 +1616,7 @@ def _needs_subset(self):
if self._subdomain_id == "everywhere":
return False
elif self._subdomain_id == "otherwise":
return self._all_integer_subdomain_ids is not None
return self._all_integer_subdomain_ids.get(self._kinfo.integral_type, None) is not None
else:
return True

Expand All @@ -1631,7 +1636,7 @@ def _get_dim(self, finat_element):

def _make_dat_global_kernel_arg(self, V, index=None):
finat_element = create_element(V.ufl_element())
map_arg = V.topological.entity_node_map(self._integral_type)._global_kernel_arg
map_arg = V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg
if isinstance(finat_element, finat.EnrichedElement) and finat_element.is_mixed:
assert index is None
subargs = tuple(self._make_dat_global_kernel_arg(Vsub, index=index)
Expand All @@ -1649,7 +1654,7 @@ def _make_mat_global_kernel_arg(self, Vrow, Vcol):
shape = len(relem.elements), len(celem.elements)
return op2.MixedMatKernelArg(subargs, shape)
else:
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._integral_type)._global_kernel_arg for V in [Vrow, Vcol])
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg for V in [Vrow, Vcol])
# PyOP2 matrix objects have scalar dims so we flatten them here
rdim = numpy.prod(self._get_dim(relem), dtype=int)
cdim = numpy.prod(self._get_dim(celem), dtype=int)
Expand Down Expand Up @@ -1952,7 +1957,7 @@ def _iterset(self):
def _get_map(self, V):
"""Return the appropriate PyOP2 map for a given function space."""
assert isinstance(V, (WithGeometry, FiredrakeDualSpace, FunctionSpace))
return V.entity_node_map(self._integral_type)
return V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)

def _as_parloop_arg(self, tsfc_arg):
"""Return a :class:`op2.ParloopArg` corresponding to the provided
Expand Down
241 changes: 241 additions & 0 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3623,3 +3623,244 @@ def create_halo_exchange_sf(PETSc.DM dm):
halo_exchange_sf = PETSc.SF().create(comm=point_sf.comm)
CHKERR(PetscSFSetGraph(halo_exchange_sf.sf, m, n, dof_ilocal, PETSC_OWN_POINTER, dof_iremote, PETSC_OWN_POINTER))
return halo_exchange_sf


# -- submesh --


@cython.boundscheck(False)
@cython.wraparound(False)
def submesh_create(PETSc.DM dm,
label_name,
PetscInt label_value):
"""Create submesh.
Parameters
----------
dm : PETSc.DM
DMPlex representing the mesh topology
label_name : str
Name of the label
label_value : int
Value in the label
"""
cdef:
PETSc.DM subdm = PETSc.DMPlex()
PETSc.DMLabel label
PETSc.SF ownership_transfer_sf = PETSc.SF()

label = dm.getLabel(label_name)
CHKERR(DMPlexFilter(dm.dm, label.dmlabel, label_value, PETSC_FALSE, PETSC_TRUE, &ownership_transfer_sf.sf, &subdm.dm))
submesh_update_facet_labels(dm, subdm)
submesh_correct_entity_classes(dm, subdm, ownership_transfer_sf)
return subdm


@cython.boundscheck(False)
@cython.wraparound(False)
def submesh_correct_entity_classes(PETSc.DM dm,
PETSc.DM subdm,
PETSc.SF ownership_transfer_sf):
"""Correct pyop2 entity classes.
Parameters
----------
dm : PETSc.DM
The original DM.
subdm : PETSc.DM
The subdm.
ownership_transfer_sf : PETSc.SF
The ownership transfer sf.
"""
cdef:
PetscInt pStart, pEnd, p, subpStart, subpEnd, subp, nsubpoints
PetscInt nroots, nleaves, i
const PetscInt *ilocal = NULL
const PetscSFNode *iremote = NULL
PETSc.IS subpoint_is
const PetscInt *subpoint_indices = NULL
np.ndarray[PetscInt, ndim=1, mode="c"] ownership_loss
np.ndarray[PetscInt, ndim=1, mode="c"] ownership_gain
DMLabel lbl_core, lbl_owned, lbl_ghost
PetscBool has

if dm.comm.size == 1:
return
CHKERR(DMPlexGetChart(dm.dm, &pStart, &pEnd))
CHKERR(DMPlexGetChart(subdm.dm, &subpStart, &subpEnd))
CHKERR(PetscSFGetGraph(ownership_transfer_sf.sf, &nroots, &nleaves, &ilocal, &iremote))
assert nroots == pEnd - pStart
assert pStart == 0
ownership_loss = np.zeros(pEnd - pStart, dtype=IntType)
ownership_gain = np.zeros(pEnd - pStart, dtype=IntType)
for i in range(nleaves):
p = ilocal[i] if ilocal != NULL else i
ownership_loss[p] = 1
unit = MPI._typedict[np.dtype(IntType).char]
ownership_transfer_sf.reduceBegin(unit, ownership_loss, ownership_gain, MPI.REPLACE)
ownership_transfer_sf.reduceEnd(unit, ownership_loss, ownership_gain, MPI.REPLACE)
subpoint_is = subdm.getSubpointIS()
CHKERR(ISGetSize(subpoint_is.iset, &nsubpoints))
assert nsubpoints == subpEnd - subpStart
assert subpStart == 0
CHKERR(ISGetIndices(subpoint_is.iset, &subpoint_indices))
CHKERR(DMGetLabel(subdm.dm, b"pyop2_core", &lbl_core))
CHKERR(DMGetLabel(subdm.dm, b"pyop2_owned", &lbl_owned))
CHKERR(DMGetLabel(subdm.dm, b"pyop2_ghost", &lbl_ghost))
CHKERR(DMLabelCreateIndex(lbl_core, subpStart, subpEnd))
CHKERR(DMLabelCreateIndex(lbl_owned, subpStart, subpEnd))
CHKERR(DMLabelCreateIndex(lbl_ghost, subpStart, subpEnd))
for subp in range(subpStart, subpEnd):
p = subpoint_indices[subp]
if ownership_loss[p] == 1:
CHKERR(DMLabelHasPoint(lbl_core, subp, &has))
assert has == PETSC_FALSE
CHKERR(DMLabelHasPoint(lbl_owned, subp, &has))
assert has == PETSC_TRUE
CHKERR(DMLabelClearValue(lbl_owned, subp, 1))
CHKERR(DMLabelSetValue(lbl_ghost, subp, 1))
if ownership_gain[p] == 1:
CHKERR(DMLabelHasPoint(lbl_core, subp, &has))
assert has == PETSC_FALSE
CHKERR(DMLabelHasPoint(lbl_ghost, subp, &has))
assert has == PETSC_TRUE
CHKERR(DMLabelClearValue(lbl_ghost, subp, 1))
CHKERR(DMLabelSetValue(lbl_owned, subp, 1))
CHKERR(DMLabelDestroyIndex(lbl_core))
CHKERR(DMLabelDestroyIndex(lbl_owned))
CHKERR(DMLabelDestroyIndex(lbl_ghost))
CHKERR(ISRestoreIndices(subpoint_is.iset, &subpoint_indices))


@cython.boundscheck(False)
@cython.wraparound(False)
def submesh_update_facet_labels(PETSc.DM dm, PETSc.DM subdm):
"""Update facet labels of subdm taking the new exterior facet points into account.
Parameters
----------
dm : PETSc.DM
The parent dm.
subdm : PETSc.DM
The subdm.
Notes
-----
This function marks the new exterior facets with current max label value + 1 in "Face Sets".
"""
cdef:
PetscInt dim, subdim, pStart, pEnd, f, subfStart, subfEnd, subf, sub_ext_facet_size, next_label_val, i
PETSc.IS subpoint_is
PETSc.IS sub_ext_facet_is
const PetscInt *subpoint_indices = NULL
const PetscInt *sub_ext_facet_indices = NULL
char *int_facet_label_name = <char *>"interior_facets"
char *ext_facet_label_name = <char *>"exterior_facets"
char *face_sets_label_name = <char *>"Face Sets"
DMLabel ext_facet_label
PETSc.DMLabel sub_int_facet_label, sub_ext_facet_label
PetscBool has_point

# Mark interior and exterior facets
label_facets(subdm)
sub_int_facet_label = subdm.getLabel("interior_facets")
sub_ext_facet_label = subdm.getLabel("exterior_facets")
# Mark new exterior facets with current max label value + 1 in "Face Sets"
dim = dm.getDimension()
subdim = subdm.getDimension()
subpoint_is = subdm.getSubpointIS()
CHKERR(ISGetIndices(subpoint_is.iset, &subpoint_indices))
if subdim == dim:
with dm.getLabelIdIS(FACE_SETS_LABEL) as label_value_indices:
next_label_val = label_value_indices.max() + 1 if len(label_value_indices) > 0 else 0
next_label_val = dm.comm.tompi4py().allreduce(next_label_val, op=MPI.MAX)
subdm.createLabel(FACE_SETS_LABEL)
sub_ext_facet_size = subdm.getStratumSize("exterior_facets", 1)
sub_ext_facet_is = subdm.getStratumIS("exterior_facets", 1)
if sub_ext_facet_is.iset:
CHKERR(ISGetIndices(sub_ext_facet_is.iset, &sub_ext_facet_indices))
CHKERR(DMGetLabel(dm.dm, ext_facet_label_name, &ext_facet_label))
pStart, pEnd = dm.getChart()
CHKERR(DMLabelCreateIndex(ext_facet_label, pStart, pEnd))
subfStart, subfEnd = subdm.getHeightStratum(1)
for i in range(sub_ext_facet_size):
subf = sub_ext_facet_indices[i]
if subf < subfStart or subf >= subfEnd:
continue
f = subpoint_indices[subf]
CHKERR(DMLabelHasPoint(ext_facet_label, f, &has_point))
if not has_point:
# Found a new exterior facet
CHKERR(DMSetLabelValue(subdm.dm, face_sets_label_name, subf, next_label_val))
CHKERR(DMLabelDestroyIndex(ext_facet_label))
if sub_ext_facet_is.iset:
CHKERR(ISRestoreIndices(sub_ext_facet_is.iset, &sub_ext_facet_indices))
else:
raise NotImplementedError("Currently, only implemented for cell submesh")
CHKERR(ISRestoreIndices(subpoint_is.iset, &subpoint_indices))
subdm.removeLabel("interior_facets")
subdm.removeLabel("exterior_facets")


@cython.boundscheck(False)
@cython.wraparound(False)
def submesh_create_cell_closure_cell_submesh(PETSc.DM subdm,
PETSc.DM dm,
PETSc.Section subcell_numbering,
PETSc.Section cell_numbering,
np.ndarray[PetscInt, ndim=2, mode="c"] cell_closure):
"""Inherit cell_closure from parent.
Parameters
----------
subdm : PETSc.DM
The subdm.
dm : PETSc.DM
The parent dm.
subcell_numbering : PETSc.Section
The cell_numbering of the submesh.
cell_numbering : PETSc.Section
The cell_numbering of the parent mesh.
cell_closure : numpy.ndarray
The cell_closure of the parent mesh.
"""
cdef:
PETSc.IS subpoint_is
const PetscInt *subpoint_indices = NULL
PetscInt *subpoint_indices_inv = NULL
PetscInt subpStart, subpEnd, subp, subcStart, subcEnd, subc, subcell
PetscInt pStart, pEnd, p, cStart, cEnd, c, cell
PetscInt nclosure, cl
np.ndarray[PetscInt, ndim=2, mode="c"] subcell_closure

get_chart(subdm.dm, &subpStart, &subpEnd)
get_height_stratum(subdm.dm, 0, &subcStart, &subcEnd)
get_chart(dm.dm, &pStart, &pEnd)
get_height_stratum(dm.dm, 0, &cStart, &cEnd)
subpoint_is = subdm.getSubpointIS()
CHKERR(ISGetIndices(subpoint_is.iset, &subpoint_indices))
CHKERR(PetscMalloc1(pEnd - pStart, &subpoint_indices_inv))
for p in range(pStart, pEnd):
subpoint_indices_inv[p - pStart] = -1
for subp in range(subpStart, subpEnd):
subpoint_indices_inv[subpoint_indices[subp] - pStart] = subp
nclosure = cell_closure.shape[1]
subcell_closure = np.empty((subcEnd - subcStart, nclosure), dtype=IntType)
for subc in range(subcStart, subcEnd):
c = subpoint_indices[subc]
CHKERR(PetscSectionGetOffset(subcell_numbering.sec, subc, &subcell))
CHKERR(PetscSectionGetOffset(cell_numbering.sec, c, &cell))
for cl in range(nclosure):
p = cell_closure[cell, cl]
subp = subpoint_indices_inv[p]
if subp >= 0:
subcell_closure[subcell, cl] = subp
else:
raise RuntimeError(f"subcell = {subcell}, cell = {cell}, p = {p}, subp = {subp}")
CHKERR(PetscFree(subpoint_indices_inv))
CHKERR(ISRestoreIndices(subpoint_is.iset, &subpoint_indices))
return subcell_closure
8 changes: 8 additions & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ cdef extern from "petscdmplex.h" nogil:
int DMPlexSetAdjacencyUser(PETSc.PetscDM,int(*)(PETSc.PetscDM,PetscInt,PetscInt*,PetscInt[],void*),void*)
int DMPlexCreatePointNumbering(PETSc.PetscDM,PETSc.PetscIS*)
int DMPlexLabelComplete(PETSc.PetscDM, PETSc.PetscDMLabel)
int DMPlexDistributeOverlap(PETSc.PetscDM,PetscInt,PETSc.PetscSF*,PETSc.PetscDM*)

int DMPlexFilter(PETSc.PetscDM,PETSc.PetscDMLabel,PetscInt,PetscBool,PetscBool,PETSc.PetscSF*,PETSc.PetscDM*)
int DMPlexGetSubpointIS(PETSc.PetscDM,PETSc.PetscIS*)
int DMPlexGetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel*)
int DMPlexSetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel)

cdef extern from "petscdmlabel.h" nogil:
struct _n_DMLabel
Expand All @@ -71,6 +77,8 @@ cdef extern from "petscdm.h" nogil:
int DMCreateLabel(PETSc.PetscDM,char[])
int DMGetLabel(PETSc.PetscDM,char[],DMLabel*)
int DMGetPointSF(PETSc.PetscDM,PETSc.PetscSF*)
int DMSetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt)
int DMGetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt*)

cdef extern from "petscdmswarm.h" nogil:
int DMSwarmGetLocalSize(PETSc.PetscDM,PetscInt*)
Expand Down
Loading

0 comments on commit 37f79a2

Please sign in to comment.