Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from firedrake import utils
from firedrake.adjoint_utils import FunctionMixin
from firedrake.petsc import PETSc
from firedrake.mesh import MeshGeometry, VertexOnlyMesh
from firedrake.mesh import MeshGeometry, VertexOnlyMesh, VertexOnlyMeshTopology
from firedrake.functionspace import FunctionSpace, VectorFunctionSpace, TensorFunctionSpace


Expand Down Expand Up @@ -282,6 +282,9 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
if isinstance(function_space, Function):
self.assign(function_space)

if isinstance(V._mesh, VertexOnlyMeshTopology):
V._mesh.register_field(self)

@property
def topological(self):
r"""The underlying coordinateless function."""
Expand Down
149 changes: 112 additions & 37 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,8 +1904,17 @@ def __init__(self, swarm, parentmesh, name, reorder, input_ordering_swarm=None,
"overlap_type": (DistributedMeshOverlapType.NONE, 0)}
self.input_ordering_swarm = input_ordering_swarm
self._parent_mesh = parentmesh
self._fields = weakref.WeakSet()
super().__init__(swarm, name, reorder, None, perm_is, distribution_name, permutation_name, parentmesh.comm)

def register_field(self, f) -> None:
self._fields.add(f)

def update_fields(self) -> None:
# Update all registered fields after VOM has moved
for field in self._fields:
print(f"Hello from {field.name()}, defined on {field.function_space()}")

def _distribute(self):
pass

Expand Down Expand Up @@ -2175,6 +2184,70 @@ def input_ordering_without_halos_sf(self):
# cells first; self.cell_set.size is the number of rank-local non-halo cells.
return self.input_ordering_sf.createEmbeddedLeafSF(np.arange(self.cell_set.size, dtype=IntType))

def _update_swarm(
self, new_coords: np.ndarray, new_global_idxs: np.ndarray, new_ref_coords: np.ndarray,
new_parent_cell_nums: np.ndarray, new_ranks: np.ndarray
) -> None:
"""Updates the VOM's DMSwarm to new coordinates. Assumes there are N new coordinates,
where N is the total number of vertices in the VOM.

Parameters
----------
new_coords : np.array
An (N, gdim) array of new global coordinates for the N vertices,
where gdim is the geometric dimension of the VOM / parent mesh.
new_global_idxs : np.array
An (N,) array of new global indices for the N vertices.
new_ref_coords : np.array
An (N, tdim) array of new reference coordinates for the N vertices,
where tdim is the topological dimension of the parent mesh.
new_parent_cell_nums : np.array
An (N,) array of new parent cell numbers for the N vertices.
new_ranks : np.array
An (N,) array of new MPI ranks for the N vertices.

"""
num_vertices = new_global_idxs.shape[0]
gdim = self.geometric_dimension()

if num_vertices != new_coords.shape[0]:
raise ValueError("Number of new coordinates does not match number of global indices")
if gdim != new_coords.shape[1]:
raise ValueError("New coordinates do not have the same geometric dimension as the mesh")
if num_vertices == 0:
raise ValueError("No points to move")
if np.unique(new_global_idxs).shape[0] != num_vertices:
raise ValueError("global_indices must be unique")
if isinstance(self._parent_mesh, ExtrudedMeshTopology):
raise NotImplementedError("move_points is not implemented for extruded meshes yet")

# mesh.topology.cell_closure[:, -1] maps Firedrake cell numbers to DMplex numbers
new_plex_parent_cell_nums = self._parent_mesh.topology.cell_closure[new_parent_cell_nums, -1]

tdim = self._parent_mesh.topological_dimension
swarm = self.topology_dm

current_coords = swarm.getField("DMSwarmPIC_coor").reshape((num_vertices, gdim))
current_dmplex_parent_cell_nums = swarm.getField("DMSwarm_cellid").ravel()
current_parent_cell_nums = swarm.getField("parentcellnum").ravel()
current_ref_coords = swarm.getField("refcoord").reshape((num_vertices, tdim))
current_global_idxs = swarm.getField("globalindex").ravel()
current_ranks = swarm.getField("DMSwarm_rank").ravel()

current_coords[...] = new_coords
current_dmplex_parent_cell_nums[...] = new_plex_parent_cell_nums
current_parent_cell_nums[...] = new_parent_cell_nums
current_ref_coords[...] = new_ref_coords
current_global_idxs[...] = new_global_idxs
current_ranks[...] = new_ranks

swarm.restoreField("DMSwarm_rank")
swarm.restoreField("globalindex")
swarm.restoreField("refcoord")
swarm.restoreField("parentcellnum")
swarm.restoreField("DMSwarmPIC_coor")
swarm.restoreField("DMSwarm_cellid")


class CellOrientationsRuntimeError(RuntimeError):
"""Exception raised when there are problems with cell orientations."""
Expand Down Expand Up @@ -3456,22 +3529,23 @@ def other_fields(self, fields):


def _pic_swarm_in_mesh(
parent_mesh,
coords,
fields=None,
tolerance=None,
redundant=True,
exclude_halos=True,
):
"""Create a Particle In Cell (PIC) DMSwarm immersed in a Mesh
parent_mesh: AbstractMeshTopology,
coords: np.ndarray,
fields: list[Tuple[str, int, np.dtype]] | None = None,
tolerance: float | None = None,
redundant: bool = True,
exclude_halos: bool = True,
) -> Tuple[FiredrakeDMSwarm, FiredrakeDMSwarm, int]:
"""Creates a Particle In Cell (PIC) DMSwarm immersed in a Mesh.

This should only by used for meshes with straight edges. If not, the
particles may be placed in the wrong cells.

:arg parent_mesh: the :class:`Mesh` within with the DMSwarm should be
immersed.
:arg coords: an ``ndarray`` of (npoints, coordsdim) shape.
:kwarg fields: An optional list of named data which can be stored for each
Parameters
----------
parent_mesh
The parent mesh in which the DMSwarm should be immersed.
coords
An array of shape (npoints, coordsdim) defining the point coordinates.
fields
An optional list of named data which can be stored for each
point in the DMSwarm. The format should be::

[(fieldname1, blocksize1, dtype1),
Expand All @@ -3484,29 +3558,35 @@ def _pic_swarm_in_mesh(
RealType)]``. All fields must have the same number of points. For more
information see `the DMSWARM API reference
<https://petsc.org/release/manualpages/DMSwarm/DMSWARM/>_.
:kwarg tolerance: The relative tolerance (i.e. as defined on the reference
tolerance
The relative tolerance (i.e. as defined on the reference
cell) for the distance a point can be from a cell and still be
considered to be in the cell. Note that this tolerance uses an L1
distance (aka 'manhattan', 'taxicab' or rectilinear distance) so
will scale with the dimension of the mesh. The default is the parent
mesh's ``tolerance`` property. Changing this from default will
cause the parent mesh's spatial index to be rebuilt which can take some
time.
:kwarg redundant: If True, the DMSwarm will be created using only the
points specified on MPI rank 0.
:kwarg exclude_halos: If True, the DMSwarm will not contain any points in
redundant
If True, the DMSwarm will be created using only the
points specified on MPI rank 0. Defaults to True.
exclude_halos
If True, the DMSwarm will not contain any points in
the mesh halos. If False, it will but the global index of the points
in the halos will match a global index of a point which is not in the
halo.
:returns: (swarm, input_ordering_swarm, n_missing_points)
- swarm: the immersed DMSwarm
- input_ordering_swarm: a DMSwarm with points in the same order and with the
same rank decomposition as the supplied ``coords`` argument. This
includes any points which are not found in the parent mesh! Note
that if ``redundant=True``, all points in the generated DMSwarm
will be found on rank 0 since that was where they were taken from.
- n_missing_points: the number of points in the supplied ``coords``
argument which were not found in the parent mesh.
halo. Defaults to True.

Returns
-------
(swarm, input_ordering_swarm, n_missing_points)
- swarm: the immersed DMSwarm
- input_ordering_swarm: a DMSwarm with points in the same order and with the
same rank decomposition as the supplied ``coords`` argument. This
includes any points which are not found in the parent mesh! Note
that if ``redundant=True``, all points in the generated DMSwarm
will be found on rank 0 since that was where they were taken from.
- n_missing_points: the number of points in the supplied ``coords``
argument which were not found in the parent mesh.

.. note::

Expand Down Expand Up @@ -3569,9 +3649,7 @@ def _pic_swarm_in_mesh(
directly with PETSc's DMSwarm API. For the ``swarm`` output, this is
the parent mesh's topology DM (in most cases a DMPlex). For the
``input_ordering_swarm`` output, this is the ``swarm`` itself.

"""

if tolerance is None:
tolerance = parent_mesh.tolerance
else:
Expand Down Expand Up @@ -4113,12 +4191,9 @@ def _parent_mesh_embedding(
(ncoords_global, coords.shape[1]), dtype=coords_local.dtype
)
parent_mesh._comm.Allgatherv(coords_local, (coords_global, coords_local_sizes))
# # ncoords_local_allranks is in rank order so we can just sum up the
# # previous ranks to get the starting index for the global numbering.
# # For rank 0 we make use of the fact that sum([]) = 0.
# startidx = sum(ncoords_local_allranks[:parent_mesh._comm.rank])
# endidx = startidx + ncoords_local
# global_idxs_global = np.arange(startidx, endidx)
# ncoords_local_allranks is in rank order so we can just sum up the
# previous ranks to get the starting index for the global numbering.
# For rank 0 we make use of the fact that sum([]) = 0.
global_idxs_global = np.arange(coords_global.shape[0])
input_coords_idxs_local = np.arange(ncoords_local)
input_coords_idxs_global = np.empty(ncoords_global, dtype=int)
Expand Down
7 changes: 4 additions & 3 deletions tests/firedrake/vertexonly/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from firedrake import *
import firedrake.mesh as fd_mesh
from firedrake.utils import IntType, RealType
import pytest
import numpy as np
Expand Down Expand Up @@ -170,9 +171,9 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
# global cell midpoints only on rank 0. Note that this is the default
# behaviour so it needn't be specified explicitly.
if MPI.COMM_WORLD.rank == 0:
swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, inputpointcoords, fields=other_fields, exclude_halos=exclude_halos)
swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, inputpointcoords, fields=other_fields, exclude_halos=exclude_halos)
else:
swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, np.empty(inputpointcoords.shape), fields=other_fields, exclude_halos=exclude_halos)
swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, np.empty(inputpointcoords.shape), fields=other_fields, exclude_halos=exclude_halos)
input_rank = 0
# inputcoordindices is the correct set of input indices for
# redundant==True but I need to work out where they will be after
Expand All @@ -191,7 +192,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
# When redundant == False we expect the same behaviour by only
# supplying the local cell midpoints on each MPI ranks. Note that this
# is not the default behaviour so it must be specified explicitly.
swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, inputlocalpointcoords, fields=other_fields, redundant=redundant, exclude_halos=exclude_halos)
swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, inputlocalpointcoords, fields=other_fields, redundant=redundant, exclude_halos=exclude_halos)
input_rank = parentmesh.comm.rank
input_local_coord_indices = np.arange(len(inputlocalpointcoords))

Expand Down