From 78541d515bd24c5965dec7e4856e07dce8693778 Mon Sep 17 00:00:00 2001 From: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com> Date: Wed, 19 Jun 2024 19:21:27 +0100 Subject: [PATCH] Implement per communicator universal identifiers for Firedrake objects (#3633) * Implement per communicator universal identifiers for Firedrake objects --- firedrake/cofunction.py | 2 +- firedrake/constant.py | 6 ++-- firedrake/function.py | 2 +- firedrake/mesh.py | 14 ++++---- firedrake/utils.py | 14 +++++--- tests/slate/test_optimise.py | 4 +-- tests/unit/test_utils/test_uid.py | 56 +++++++++++++++++++++++++++++++ 7 files changed, 79 insertions(+), 19 deletions(-) create mode 100644 tests/unit/test_utils/test_uid.py diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index 7524a6b00b..b848c08f74 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -65,7 +65,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType, # Internal comm self._comm = mpi.internal_comm(V.comm, self) self._function_space = V - self.uid = utils._new_uid() + self.uid = utils._new_uid(self._comm) self._name = name or 'cofunction_%d' % self.uid self._label = "a cofunction" diff --git a/firedrake/constant.py b/firedrake/constant.py index 4000716535..1011906c1b 100644 --- a/firedrake/constant.py +++ b/firedrake/constant.py @@ -104,14 +104,12 @@ def __init__(self, value, domain=None, name=None, count=None): self.dat, rank, self._ufl_shape = _create_dat(op2.Constant, value, None) - self.uid = utils._new_uid() - self.name = name or 'constant_%d' % self.uid - super().__init__() Counted.__init__(self, count, Counted) + self.name = name or f"constant_{self._count}" def __repr__(self): - return f"Constant({self.dat.data_ro}, {self.count()})" + return f"Constant({self.dat.data_ro}, name='{self.name}', count={self._count})" def _ufl_signature_data_(self, renumbering): return (type(self).__name__, renumbering[self]) diff --git a/firedrake/function.py b/firedrake/function.py index d5a5646e1c..7307877de2 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -70,7 +70,7 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType): # Internal comm self._comm = mpi.internal_comm(function_space.comm, self) self._function_space = function_space - self.uid = utils._new_uid() + self.uid = utils._new_uid(self._comm) self._name = name or 'function_%d' % self.uid self._label = "a function" diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 729a86113d..12bc514b4e 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -1872,11 +1872,11 @@ def init(self, coordinates): class MeshGeometry(ufl.Mesh, MeshGeometryMixin): """A representation of mesh topology and geometry.""" - def __new__(cls, element): + def __new__(cls, element, comm): """Create mesh geometry object.""" utils._init() mesh = super(MeshGeometry, cls).__new__(cls) - uid = utils._new_uid() + uid = utils._new_uid(internal_comm(comm, mesh)) mesh.uid = uid cargo = MeshGeometryCargo(uid) assert isinstance(element, finat.ufl.FiniteElementBase) @@ -2446,6 +2446,8 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5): The name of the mesh. tolerance : numbers.Number The tolerance; see `Mesh`. + comm: mpi4py.Intracomm + Communicator. Returns ------- @@ -2467,7 +2469,7 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5): cell = element.cell.reconstruct(geometric_dimension=V.value_size) element = element.reconstruct(cell=cell) - mesh = MeshGeometry.__new__(MeshGeometry, element) + mesh = MeshGeometry.__new__(MeshGeometry, element, coordinates.comm) mesh.__init__(coordinates) mesh.name = name # Mark mesh as being made from coordinates @@ -2504,7 +2506,7 @@ def make_mesh_from_mesh_topology(topology, name, tolerance=0.5): else: element = finat.ufl.VectorElement("DQ" if cell in [ufl.quadrilateral, ufl.hexahedron] else "DG", cell, 1, variant="equispaced") # Create mesh object - mesh = MeshGeometry.__new__(MeshGeometry, element) + mesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm) mesh._init_topology(topology) mesh.name = name mesh._tolerance = tolerance @@ -2537,7 +2539,7 @@ def make_vom_from_vom_topology(topology, name, tolerance=0.5): tcell = topology.ufl_cell() cell = tcell.reconstruct(geometric_dimension=gdim) element = finat.ufl.VectorElement("DG", cell, 0) - vmesh = MeshGeometry.__new__(MeshGeometry, element) + vmesh = MeshGeometry.__new__(MeshGeometry, element, topology.comm) vmesh._init_topology(topology) # Save vertex reference coordinate (within reference cell) in function parent_tdim = topology._parent_mesh.ufl_cell().topological_dimension() @@ -2709,7 +2711,7 @@ def Mesh(meshfile, **kwargs): comm=user_comm) mesh = make_mesh_from_mesh_topology(topology, name) if netgen and isinstance(meshfile, netgen.libngpy._meshing.Mesh): - netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName()) + netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName(), comm=user_comm) mesh = netgen_firedrake_mesh.firedrakeMesh mesh._tolerance = tolerance return mesh diff --git a/firedrake/utils.py b/firedrake/utils.py index c2b5440e57..2dd768fb28 100644 --- a/firedrake/utils.py +++ b/firedrake/utils.py @@ -6,9 +6,11 @@ from pyop2.datatypes import RealType # noqa: F401 from pyop2.datatypes import IntType # noqa: F401 from pyop2.datatypes import as_ctypes # noqa: F401 +from pyop2.mpi import MPI from firedrake_configuration import get_config -_current_uid = 0 +# MPI key value for storing a per communicator universal identifier +FIREDRAKE_UID = MPI.Comm.Create_keyval() RealType_c = as_cstr(RealType) ScalarType_c = as_cstr(ScalarType) @@ -20,10 +22,12 @@ SLATE_SUPPORTS_COMPLEX = False -def _new_uid(): - global _current_uid - _current_uid += 1 - return _current_uid +def _new_uid(comm): + uid = comm.Get_attr(FIREDRAKE_UID) + if uid is None: + uid = 0 + comm.Set_attr(FIREDRAKE_UID, uid + 1) + return uid def _init(): diff --git a/tests/slate/test_optimise.py b/tests/slate/test_optimise.py index 13710c3617..276369df61 100644 --- a/tests/slate/test_optimise.py +++ b/tests/slate/test_optimise.py @@ -266,7 +266,7 @@ def test_push_mul_nested(TC, TC_without_trace, TC_non_symm): opt_expressions = [T*C+T*C+T2*C, T*C+T2*C+T*C, T*C-(T*C)+T2*C, T*C+T2*C-(T*C), T*T.solve(C), T.solve(T*C), T2*T.solve(C), T2*T.solve(T*C), (T.T.solve(T.T.solve(C))).T*T, (T3*(T3.T.solve(C3.T))).T, (T3.T.solve(T3.T.solve(C3))).T*T3] - compare_vector_expressions_mixed(expressions, rtol=1e-11) + compare_vector_expressions_mixed(expressions, rtol=1e-10) compare_slate_tensors(expressions, opt_expressions) # Make sure replacing inverse by solves does not introduce errors @@ -325,7 +325,7 @@ def test_partially_optimised(TC_non_symm, TC_double_mass, TC): A*A.solve(A*C+A*C)+A*A.solve(A*C+A*C), (A.T.solve(A.T.solve(C.T))).T*A] - compare_vector_expressions(expressions, rtol=1e-11) + compare_vector_expressions(expressions, rtol=1e-10) compare_slate_tensors(expressions, opt_expressions) # Make sure optimised solve gives same answer as expression with inverses diff --git a/tests/unit/test_utils/test_uid.py b/tests/unit/test_utils/test_uid.py new file mode 100644 index 0000000000..042fea2384 --- /dev/null +++ b/tests/unit/test_utils/test_uid.py @@ -0,0 +1,56 @@ +import pytest +import numpy as np + +from firedrake import * +from firedrake.utils import _new_uid +from functools import partial + + +def make_mesh(comm): + return UnitSquareMesh(5, 5, comm=comm) + + +def make_function(comm): + mesh = UnitSquareMesh(5, 5, comm=comm) + V = FunctionSpace(mesh, 'Lagrange', 1) + return Function(V) + + +def make_cofunction(comm): + mesh = UnitSquareMesh(5, 5, comm=comm) + V = FunctionSpace(mesh, 'Lagrange', 1) + return Cofunction(V.dual()) + + +def make_constant(comm): + return Constant(677) + + +@pytest.fixture( + params=[Function, Cofunction, Constant, UnitSquareMesh], + ids=["Function", "Cofunction", "Constant", "Mesh"] +) +def obj(request): + ''' Make a callable for creating a Firedrake object + ''' + case = { + Function: make_function, + Cofunction: make_cofunction, + Constant: make_constant, + UnitSquareMesh: make_mesh + } + return partial(case[request.param]) + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) +def test_monotonic_uid(obj): + object_parallel = obj(comm=COMM_WORLD) # noqa: F841 + + if COMM_WORLD.rank == 0: + object_serial = obj(comm=COMM_SELF) # noqa: F841 + + for comm in [COMM_WORLD, COMM_SELF]: + new = np.array([_new_uid(comm)]) + all_new = np.array([-1]*comm.size) + comm.Allgather(new, all_new) + assert all([a == all_new[comm.rank] for a in all_new])