Skip to content

Commit

Permalink
Implement per communicator universal identifiers for Firedrake objects (
Browse files Browse the repository at this point in the history
#3633)

* Implement per communicator universal identifiers for Firedrake objects
  • Loading branch information
JDBetteridge authored and Ig-dolci committed Jul 6, 2024
1 parent addd00d commit 584c8a5
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 19 deletions.
2 changes: 1 addition & 1 deletion firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
# Internal comm
self._comm = mpi.internal_comm(V.comm, self)
self._function_space = V
self.uid = utils._new_uid()
self.uid = utils._new_uid(self._comm)
self._name = name or 'cofunction_%d' % self.uid
self._label = "a cofunction"

Expand Down
6 changes: 2 additions & 4 deletions firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@ def __init__(self, value, domain=None, name=None, count=None):

self.dat, rank, self._ufl_shape = _create_dat(op2.Constant, value, None)

self.uid = utils._new_uid()
self.name = name or 'constant_%d' % self.uid

super().__init__()
Counted.__init__(self, count, Counted)
self.name = name or f"constant_{self._count}"

def __repr__(self):
return f"Constant({self.dat.data_ro}, {self.count()})"
return f"Constant({self.dat.data_ro}, name='{self.name}', count={self._count})"

def _ufl_signature_data_(self, renumbering):
return (type(self).__name__, renumbering[self])
Expand Down
2 changes: 1 addition & 1 deletion firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType):
# Internal comm
self._comm = mpi.internal_comm(function_space.comm, self)
self._function_space = function_space
self.uid = utils._new_uid()
self.uid = utils._new_uid(self._comm)
self._name = name or 'function_%d' % self.uid
self._label = "a function"

Expand Down
14 changes: 8 additions & 6 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,11 +1872,11 @@ def init(self, coordinates):
class MeshGeometry(ufl.Mesh, MeshGeometryMixin):
"""A representation of mesh topology and geometry."""

def __new__(cls, element):
def __new__(cls, element, comm):
"""Create mesh geometry object."""
utils._init()
mesh = super(MeshGeometry, cls).__new__(cls)
uid = utils._new_uid()
uid = utils._new_uid(internal_comm(comm, mesh))
mesh.uid = uid
cargo = MeshGeometryCargo(uid)
assert isinstance(element, finat.ufl.FiniteElementBase)
Expand Down Expand Up @@ -2446,6 +2446,8 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
The name of the mesh.
tolerance : numbers.Number
The tolerance; see `Mesh`.
comm: mpi4py.Intracomm
Communicator.
Returns
-------
Expand All @@ -2467,7 +2469,7 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
cell = element.cell.reconstruct(geometric_dimension=V.value_size)
element = element.reconstruct(cell=cell)

mesh = MeshGeometry.__new__(MeshGeometry, element)
mesh = MeshGeometry.__new__(MeshGeometry, element, coordinates.comm)
mesh.__init__(coordinates)
mesh.name = name
# Mark mesh as being made from coordinates
Expand Down Expand Up @@ -2504,7 +2506,7 @@ def make_mesh_from_mesh_topology(topology, name, tolerance=0.5):
else:
element = finat.ufl.VectorElement("DQ" if cell in [ufl.quadrilateral, ufl.hexahedron] else "DG", cell, 1, variant="equispaced")
# Create mesh object
mesh = MeshGeometry.__new__(MeshGeometry, element)
mesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
mesh._init_topology(topology)
mesh.name = name
mesh._tolerance = tolerance
Expand Down Expand Up @@ -2537,7 +2539,7 @@ def make_vom_from_vom_topology(topology, name, tolerance=0.5):
tcell = topology.ufl_cell()
cell = tcell.reconstruct(geometric_dimension=gdim)
element = finat.ufl.VectorElement("DG", cell, 0)
vmesh = MeshGeometry.__new__(MeshGeometry, element)
vmesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
vmesh._init_topology(topology)
# Save vertex reference coordinate (within reference cell) in function
parent_tdim = topology._parent_mesh.ufl_cell().topological_dimension()
Expand Down Expand Up @@ -2709,7 +2711,7 @@ def Mesh(meshfile, **kwargs):
comm=user_comm)
mesh = make_mesh_from_mesh_topology(topology, name)
if netgen and isinstance(meshfile, netgen.libngpy._meshing.Mesh):
netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName())
netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName(), comm=user_comm)
mesh = netgen_firedrake_mesh.firedrakeMesh
mesh._tolerance = tolerance
return mesh
Expand Down
14 changes: 9 additions & 5 deletions firedrake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from pyop2.datatypes import RealType # noqa: F401
from pyop2.datatypes import IntType # noqa: F401
from pyop2.datatypes import as_ctypes # noqa: F401
from pyop2.mpi import MPI
from firedrake_configuration import get_config

_current_uid = 0
# MPI key value for storing a per communicator universal identifier
FIREDRAKE_UID = MPI.Comm.Create_keyval()

RealType_c = as_cstr(RealType)
ScalarType_c = as_cstr(ScalarType)
Expand All @@ -20,10 +22,12 @@
SLATE_SUPPORTS_COMPLEX = False


def _new_uid():
global _current_uid
_current_uid += 1
return _current_uid
def _new_uid(comm):
uid = comm.Get_attr(FIREDRAKE_UID)
if uid is None:
uid = 0
comm.Set_attr(FIREDRAKE_UID, uid + 1)
return uid


def _init():
Expand Down
4 changes: 2 additions & 2 deletions tests/slate/test_optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_push_mul_nested(TC, TC_without_trace, TC_non_symm):
opt_expressions = [T*C+T*C+T2*C, T*C+T2*C+T*C, T*C-(T*C)+T2*C, T*C+T2*C-(T*C),
T*T.solve(C), T.solve(T*C), T2*T.solve(C), T2*T.solve(T*C),
(T.T.solve(T.T.solve(C))).T*T, (T3*(T3.T.solve(C3.T))).T, (T3.T.solve(T3.T.solve(C3))).T*T3]
compare_vector_expressions_mixed(expressions, rtol=1e-11)
compare_vector_expressions_mixed(expressions, rtol=1e-10)
compare_slate_tensors(expressions, opt_expressions)

# Make sure replacing inverse by solves does not introduce errors
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_partially_optimised(TC_non_symm, TC_double_mass, TC):
A*A.solve(A*C+A*C)+A*A.solve(A*C+A*C),
(A.T.solve(A.T.solve(C.T))).T*A]

compare_vector_expressions(expressions, rtol=1e-11)
compare_vector_expressions(expressions, rtol=1e-10)
compare_slate_tensors(expressions, opt_expressions)

# Make sure optimised solve gives same answer as expression with inverses
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_utils/test_uid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import numpy as np

from firedrake import *
from firedrake.utils import _new_uid
from functools import partial


def make_mesh(comm):
return UnitSquareMesh(5, 5, comm=comm)


def make_function(comm):
mesh = UnitSquareMesh(5, 5, comm=comm)
V = FunctionSpace(mesh, 'Lagrange', 1)
return Function(V)


def make_cofunction(comm):
mesh = UnitSquareMesh(5, 5, comm=comm)
V = FunctionSpace(mesh, 'Lagrange', 1)
return Cofunction(V.dual())


def make_constant(comm):
return Constant(677)


@pytest.fixture(
params=[Function, Cofunction, Constant, UnitSquareMesh],
ids=["Function", "Cofunction", "Constant", "Mesh"]
)
def obj(request):
''' Make a callable for creating a Firedrake object
'''
case = {
Function: make_function,
Cofunction: make_cofunction,
Constant: make_constant,
UnitSquareMesh: make_mesh
}
return partial(case[request.param])


@pytest.mark.parallel(nprocs=[1, 2, 3, 4])
def test_monotonic_uid(obj):
object_parallel = obj(comm=COMM_WORLD) # noqa: F841

if COMM_WORLD.rank == 0:
object_serial = obj(comm=COMM_SELF) # noqa: F841

for comm in [COMM_WORLD, COMM_SELF]:
new = np.array([_new_uid(comm)])
all_new = np.array([-1]*comm.size)
comm.Allgather(new, all_new)
assert all([a == all_new[comm.rank] for a in all_new])

0 comments on commit 584c8a5

Please sign in to comment.