Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement per communicator universal identifiers for Firedrake objects #3633

Merged
merged 8 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
# Internal comm
self._comm = mpi.internal_comm(V.comm, self)
self._function_space = V
self.uid = utils._new_uid()
self.uid = utils._new_uid(self._comm)
self._name = name or 'cofunction_%d' % self.uid
self._label = "a cofunction"

Expand Down
6 changes: 2 additions & 4 deletions firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,12 @@ def __init__(self, value, domain=None, name=None, count=None):

self.dat, rank, self._ufl_shape = _create_dat(op2.Constant, value, None)

self.uid = utils._new_uid()
self.name = name or 'constant_%d' % self.uid

super().__init__()
Counted.__init__(self, count, Counted)
self.name = name or f"constant_{self._count}"

def __repr__(self):
return f"Constant({self.dat.data_ro}, {self.count()})"
return f"Constant({self.dat.data_ro}, name='{self.name}', count={self._count})"

def _ufl_signature_data_(self, renumbering):
return (type(self).__name__, renumbering[self])
Expand Down
2 changes: 1 addition & 1 deletion firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType):
# Internal comm
self._comm = mpi.internal_comm(function_space.comm, self)
self._function_space = function_space
self.uid = utils._new_uid()
self.uid = utils._new_uid(self._comm)
self._name = name or 'function_%d' % self.uid
self._label = "a function"

Expand Down
14 changes: 8 additions & 6 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1872,11 +1872,11 @@ def init(self, coordinates):
class MeshGeometry(ufl.Mesh, MeshGeometryMixin):
"""A representation of mesh topology and geometry."""

def __new__(cls, element):
def __new__(cls, element, comm):
"""Create mesh geometry object."""
utils._init()
mesh = super(MeshGeometry, cls).__new__(cls)
uid = utils._new_uid()
uid = utils._new_uid(internal_comm(comm, mesh))
mesh.uid = uid
cargo = MeshGeometryCargo(uid)
assert isinstance(element, finat.ufl.FiniteElementBase)
Expand Down Expand Up @@ -2446,6 +2446,8 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
The name of the mesh.
tolerance : numbers.Number
The tolerance; see `Mesh`.
comm: mpi4py.Intracomm
Communicator.

Returns
-------
Expand All @@ -2467,7 +2469,7 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
cell = element.cell.reconstruct(geometric_dimension=V.value_size)
element = element.reconstruct(cell=cell)

mesh = MeshGeometry.__new__(MeshGeometry, element)
mesh = MeshGeometry.__new__(MeshGeometry, element, coordinates.comm)
mesh.__init__(coordinates)
mesh.name = name
# Mark mesh as being made from coordinates
Expand Down Expand Up @@ -2504,7 +2506,7 @@ def make_mesh_from_mesh_topology(topology, name, tolerance=0.5):
else:
element = finat.ufl.VectorElement("DQ" if cell in [ufl.quadrilateral, ufl.hexahedron] else "DG", cell, 1, variant="equispaced")
# Create mesh object
mesh = MeshGeometry.__new__(MeshGeometry, element)
mesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
mesh._init_topology(topology)
mesh.name = name
mesh._tolerance = tolerance
Expand Down Expand Up @@ -2537,7 +2539,7 @@ def make_vom_from_vom_topology(topology, name, tolerance=0.5):
tcell = topology.ufl_cell()
cell = tcell.reconstruct(geometric_dimension=gdim)
element = finat.ufl.VectorElement("DG", cell, 0)
vmesh = MeshGeometry.__new__(MeshGeometry, element)
vmesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm)
vmesh._init_topology(topology)
# Save vertex reference coordinate (within reference cell) in function
parent_tdim = topology._parent_mesh.ufl_cell().topological_dimension()
Expand Down Expand Up @@ -2709,7 +2711,7 @@ def Mesh(meshfile, **kwargs):
comm=user_comm)
mesh = make_mesh_from_mesh_topology(topology, name)
if netgen and isinstance(meshfile, netgen.libngpy._meshing.Mesh):
netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName())
netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName(), comm=user_comm)
mesh = netgen_firedrake_mesh.firedrakeMesh
mesh._tolerance = tolerance
return mesh
Expand Down
14 changes: 9 additions & 5 deletions firedrake/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from pyop2.datatypes import RealType # noqa: F401
from pyop2.datatypes import IntType # noqa: F401
from pyop2.datatypes import as_ctypes # noqa: F401
from pyop2.mpi import MPI
from firedrake_configuration import get_config

_current_uid = 0
# MPI key value for storing a per communicator universal identifier
FIREDRAKE_UID = MPI.Comm.Create_keyval()

RealType_c = as_cstr(RealType)
ScalarType_c = as_cstr(ScalarType)
Expand All @@ -20,10 +22,12 @@
SLATE_SUPPORTS_COMPLEX = False


def _new_uid():
global _current_uid
_current_uid += 1
return _current_uid
def _new_uid(comm):
uid = comm.Get_attr(FIREDRAKE_UID)
if uid is None:
uid = 0
comm.Set_attr(FIREDRAKE_UID, uid + 1)
return uid


def _init():
Expand Down
4 changes: 2 additions & 2 deletions tests/slate/test_optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_push_mul_nested(TC, TC_without_trace, TC_non_symm):
opt_expressions = [T*C+T*C+T2*C, T*C+T2*C+T*C, T*C-(T*C)+T2*C, T*C+T2*C-(T*C),
T*T.solve(C), T.solve(T*C), T2*T.solve(C), T2*T.solve(T*C),
(T.T.solve(T.T.solve(C))).T*T, (T3*(T3.T.solve(C3.T))).T, (T3.T.solve(T3.T.solve(C3))).T*T3]
compare_vector_expressions_mixed(expressions, rtol=1e-11)
compare_vector_expressions_mixed(expressions, rtol=1e-10)
compare_slate_tensors(expressions, opt_expressions)

# Make sure replacing inverse by solves does not introduce errors
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_partially_optimised(TC_non_symm, TC_double_mass, TC):
A*A.solve(A*C+A*C)+A*A.solve(A*C+A*C),
(A.T.solve(A.T.solve(C.T))).T*A]

compare_vector_expressions(expressions, rtol=1e-11)
compare_vector_expressions(expressions, rtol=1e-10)
compare_slate_tensors(expressions, opt_expressions)

# Make sure optimised solve gives same answer as expression with inverses
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_utils/test_uid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import numpy as np

from firedrake import *
from firedrake.utils import _new_uid
from functools import partial


def make_mesh(comm):
return UnitSquareMesh(5, 5, comm=comm)


def make_function(comm):
mesh = UnitSquareMesh(5, 5, comm=comm)
V = FunctionSpace(mesh, 'Lagrange', 1)
return Function(V)


def make_cofunction(comm):
mesh = UnitSquareMesh(5, 5, comm=comm)
V = FunctionSpace(mesh, 'Lagrange', 1)
return Cofunction(V.dual())


def make_constant(comm):
return Constant(677)


@pytest.fixture(
params=[Function, Cofunction, Constant, UnitSquareMesh],
ids=["Function", "Cofunction", "Constant", "Mesh"]
)
def obj(request):
''' Make a callable for creating a Firedrake object
'''
case = {
Function: make_function,
Cofunction: make_cofunction,
Constant: make_constant,
UnitSquareMesh: make_mesh
}
return partial(case[request.param])


@pytest.mark.parallel(nprocs=[1, 2, 3, 4])
def test_monotonic_uid(obj):
object_parallel = obj(comm=COMM_WORLD) # noqa: F841

if COMM_WORLD.rank == 0:
object_serial = obj(comm=COMM_SELF) # noqa: F841

for comm in [COMM_WORLD, COMM_SELF]:
new = np.array([_new_uid(comm)])
all_new = np.array([-1]*comm.size)
comm.Allgather(new, all_new)
assert all([a == all_new[comm.rank] for a in all_new])
Loading