diff --git a/AUTHORS b/AUTHORS index f84cf5780..49dac8204 100644 --- a/AUTHORS +++ b/AUTHORS @@ -21,3 +21,4 @@ Florian Rathgeber Francis Russell Kaho Sato Reuben W. Nixon-Hill +Nacime Bouziani diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 07a40e98a..ed2e6f66c 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -145,6 +145,8 @@ def data(self): :meth:`data_with_halos`. """ + # Increment dat_version since this accessor assumes data modification + self.increment_dat_version() if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0: raise RuntimeError("Illegal access: no data associated with this Dat!") self.halo_valid = False @@ -255,6 +257,8 @@ def zero(self, subset=None): """Zero the data associated with this :class:`Dat` :arg subset: A :class:`Subset` of entries to zero (optional).""" + # Data modification + self.increment_dat_version() # If there is no subset we can safely zero the halo values. if subset is None: self._data[:] = 0 @@ -668,6 +672,12 @@ def data_ro_with_halos(self): class Dat(AbstractDat, VecAccessMixin): + def __init__(self, *args, **kwargs): + AbstractDat.__init__(self, *args, **kwargs) + # Determine if we can rely on PETSc state counter + petsc_counter = (self.dtype == PETSc.ScalarType) + VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + @utils.cached_property def _vec(self): assert self.dtype == PETSc.ScalarType, \ @@ -685,11 +695,6 @@ def vec_context(self, access): r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`. :param access: Access descriptor: READ, WRITE, or RW.""" - # PETSc Vecs have a state counter and cache norm computations - # to return immediately if the state counter is unchanged. - # Since we've updated the data behind their back, we need to - # change that state counter. - self._vec.stateIncrease() yield self._vec if access is not Access.READ: self.halo_valid = False @@ -729,6 +734,10 @@ def what(x): # TODO: Think about different communicators on dats (c.f. MixedSet) self.comm = self._dats[0].comm + @property + def dat_version(self): + return sum(d.dat_version for d in self._dats) + def __call__(self, access, path=None): from pyop2.parloop import MixedDatLegacyArg return MixedDatLegacyArg(self, path, access) @@ -1012,7 +1021,6 @@ def vec_context(self, access): size = v.local_size array[offset:offset+size] = v.array_r[:] offset += size - self._vec.stateIncrease() yield self._vec if access is not Access.READ: # Reverse scatter to get the values back to their original locations diff --git a/pyop2/types/data_carrier.py b/pyop2/types/data_carrier.py index 78a268a84..73d3974c2 100644 --- a/pyop2/types/data_carrier.py +++ b/pyop2/types/data_carrier.py @@ -44,6 +44,9 @@ def cdim(self): the product of the dim tuple.""" return self._cdim + def increment_dat_version(self): + pass + class EmptyDataMixin(abc.ABC): """A mixin for :class:`Dat` and :class:`Global` objects that takes @@ -75,6 +78,27 @@ def _is_allocated(self): class VecAccessMixin(abc.ABC): + + def __init__(self, petsc_counter=None): + if petsc_counter: + # Use lambda since `_vec` allocates the data buffer + # -> Dat/Global should not allocate storage until accessed + self._dat_version = lambda: self._vec.stateGet() + self.increment_dat_version = lambda: self._vec.stateIncrease() + else: + # No associated PETSc Vec if incompatible type: + # -> Equip Dat/Global with their own counter. + self._version = 0 + self._dat_version = lambda: self._version + + def _inc(): + self._version += 1 + self.increment_dat_version = _inc + + @property + def dat_version(self): + return self._dat_version() + @abc.abstractmethod def vec_context(self, access): pass diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 86b713cef..05fa0b4f5 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -51,6 +51,9 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "global_#x%x" % id(self) self.comm = comm + # Object versioning setup + petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) + VecAccessMixin.__init__(self, petsc_counter=petsc_counter) @utils.cached_property def _kernel_args_(self): @@ -104,6 +107,7 @@ def shape(self): @property def data(self): """Data array.""" + self.increment_dat_version() if len(self._data) == 0: raise RuntimeError("Illegal access: No data associated with this Global!") return self._data @@ -115,12 +119,13 @@ def dtype(self): @property def data_ro(self): """Data array.""" - view = self.data.view() + view = self._data.view() view.setflags(write=False) return view @data.setter def data(self, value): + self.increment_dat_version() self._data[:] = utils.verify_reshape(value, self.dtype, self.dim) @property @@ -153,6 +158,7 @@ def copy(self, other, subset=None): @mpi.collective def zero(self, subset=None): assert subset is None + self.increment_dat_version() self._data[...] = 0 @mpi.collective @@ -282,11 +288,6 @@ def vec_context(self, access): """A context manager for a :class:`PETSc.Vec` from a :class:`Global`. :param access: Access descriptor: READ, WRITE, or RW.""" - # PETSc Vecs have a state counter and cache norm computations - # to return immediately if the state counter is unchanged. - # Since we've updated the data behind their back, we need to - # change that state counter. - self._vec.stateIncrease() yield self._vec if access is not Access.READ: data = self._data diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index c7dc06f3f..723647edc 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -824,6 +824,12 @@ def __iter__(self): """Iterate over all :class:`Mat` blocks by row and then by column.""" yield from itertools.chain(*self.blocks) + @property + def dat_version(self): + if self.assembly_state != Mat.ASSEMBLED: + raise RuntimeError("Should not ask for state counter if the matrix is not assembled.") + return self.handle.stateGet() + @mpi.collective def zero(self): """Zero the matrix.""" @@ -936,6 +942,10 @@ def __init__(self, parent, i, j): self.comm = parent.comm self.local_to_global_maps = self.handle.getLGMap() + @property + def dat_version(self): + return self.handle.stateGet() + @utils.cached_property def _kernel_args_(self): return (self.handle.handle, ) diff --git a/test/unit/test_api.py b/test/unit/test_api.py index 6ea2a6832..8f61805c8 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -832,8 +832,8 @@ def test_dat_zero_cdim(self, set): dset = set**0 d = op2.Dat(dset) assert d.shape == (set.total_size, 0) - assert d.data.size == 0 - assert d.data.shape == (set.total_size, 0) + assert d._data.size == 0 + assert d._data.shape == (set.total_size, 0) class TestMixedDatAPI: diff --git a/test/unit/test_dats.py b/test/unit/test_dats.py index a34df99e2..54bb491d5 100644 --- a/test/unit/test_dats.py +++ b/test/unit/test_dats.py @@ -130,6 +130,92 @@ def test_dat_save_and_load(self, tmpdir, d1, s, mdat): mdat2.load(output) assert all(all(d.data_ro == d_.data_ro) for d, d_ in zip(mdat, mdat2)) + def test_dat_version(self, s, d1): + """Check object versioning for Dat""" + d2 = op2.Dat(s) + + assert d1.dat_version == 0 + assert d2.dat_version == 0 + + # Access data property + d1.data + + assert d1.dat_version == 1 + assert d2.dat_version == 0 + + # Access data property + d2.data[:] += 1 + + assert d1.dat_version == 1 + assert d2.dat_version == 1 + + # Access zero property + d1.zero() + + assert d1.dat_version == 2 + assert d2.dat_version == 1 + + # Copy d2 into d1 + d2.copy(d1) + + assert d1.dat_version == 3 + assert d2.dat_version == 1 + + # Context managers (without changing d1 and d2) + with d1.vec_wo as _: + pass + + with d2.vec as _: + pass + + # Dat version shouldn't change as we are just calling the context manager + # and not changing the Dat objects. + assert d1.dat_version == 3 + assert d2.dat_version == 1 + + # Context managers (modify d1 and d2) + with d1.vec_wo as x: + x += 1 + + with d2.vec as x: + x += 1 + + assert d1.dat_version == 4 + assert d2.dat_version == 2 + + def test_mixed_dat_version(self, s, d1, mdat): + """Check object versioning for MixedDat""" + d2 = op2.Dat(s) + mdat2 = op2.MixedDat([d1, d2]) + + assert mdat.dat_version == 0 + assert mdat2.dat_version == 0 + + # Access data property + mdat2.data + + # mdat2.data will call d1.data and d2.data + assert d1.dat_version == 1 + assert d2.dat_version == 1 + assert mdat.dat_version == 2 + assert mdat2.dat_version == 2 + + # Access zero property + mdat.zero() + + # mdat.zero() will call d1.zero() twice + assert d1.dat_version == 3 + assert d2.dat_version == 1 + assert mdat.dat_version == 6 + assert mdat2.dat_version == 4 + + # Access zero property + d1.zero() + + assert d1.dat_version == 4 + assert mdat.dat_version == 8 + assert mdat2.dat_version == 5 + if __name__ == '__main__': import os diff --git a/test/unit/test_globals.py b/test/unit/test_globals.py index 61449de33..b7adf57c6 100644 --- a/test/unit/test_globals.py +++ b/test/unit/test_globals.py @@ -44,3 +44,35 @@ def test_global_operations(): assert (g1 * g2).data == 10. g1 *= g2 assert g1.data == 10. + + +def test_global_dat_version(): + g1 = op2.Global(1, data=1.) + g2 = op2.Global(1, data=2.) + + assert g1.dat_version == 0 + assert g2.dat_version == 0 + + # Access data property + d1 = g1.data + + assert g1.dat_version == 1 + assert g2.dat_version == 0 + + # Access data property + g2.data[:] += 1 + + assert g1.dat_version == 1 + assert g2.dat_version == 1 + + # Access zero property + g1.zero() + + assert g1.dat_version == 2 + assert g2.dat_version == 1 + + # Access data setter + g2.data = d1 + + assert g1.dat_version == 2 + assert g2.dat_version == 2