Skip to content

Commit

Permalink
Merge pull request #579 from OP2/DataCarrier-object-versionning
Browse files Browse the repository at this point in the history
Object versioning for DataCarrier object
  • Loading branch information
dham authored Jul 21, 2022
2 parents e1f7598 + ff58332 commit 784b217
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 14 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Florian Rathgeber <florian.rathgeber@gmail.com>
Francis Russell <francis@unchartedbackwaters.co.uk>
Kaho Sato <kahosato93@gmail.com>
Reuben W. Nixon-Hill <reuben.nixon-hill10@imperial.ac.uk>
Nacime Bouziani <n.bouziani18@imperial.ac.uk>
20 changes: 14 additions & 6 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def data(self):
:meth:`data_with_halos`.
"""
# Increment dat_version since this accessor assumes data modification
self.increment_dat_version()
if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0:
raise RuntimeError("Illegal access: no data associated with this Dat!")
self.halo_valid = False
Expand Down Expand Up @@ -255,6 +257,8 @@ def zero(self, subset=None):
"""Zero the data associated with this :class:`Dat`
:arg subset: A :class:`Subset` of entries to zero (optional)."""
# Data modification
self.increment_dat_version()
# If there is no subset we can safely zero the halo values.
if subset is None:
self._data[:] = 0
Expand Down Expand Up @@ -668,6 +672,12 @@ def data_ro_with_halos(self):

class Dat(AbstractDat, VecAccessMixin):

def __init__(self, *args, **kwargs):
AbstractDat.__init__(self, *args, **kwargs)
# Determine if we can rely on PETSc state counter
petsc_counter = (self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

@utils.cached_property
def _vec(self):
assert self.dtype == PETSc.ScalarType, \
Expand All @@ -685,11 +695,6 @@ def vec_context(self, access):
r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.
:param access: Access descriptor: READ, WRITE, or RW."""
# PETSc Vecs have a state counter and cache norm computations
# to return immediately if the state counter is unchanged.
# Since we've updated the data behind their back, we need to
# change that state counter.
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
self.halo_valid = False
Expand Down Expand Up @@ -729,6 +734,10 @@ def what(x):
# TODO: Think about different communicators on dats (c.f. MixedSet)
self.comm = self._dats[0].comm

@property
def dat_version(self):
return sum(d.dat_version for d in self._dats)

def __call__(self, access, path=None):
from pyop2.parloop import MixedDatLegacyArg
return MixedDatLegacyArg(self, path, access)
Expand Down Expand Up @@ -1012,7 +1021,6 @@ def vec_context(self, access):
size = v.local_size
array[offset:offset+size] = v.array_r[:]
offset += size
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
# Reverse scatter to get the values back to their original locations
Expand Down
24 changes: 24 additions & 0 deletions pyop2/types/data_carrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def cdim(self):
the product of the dim tuple."""
return self._cdim

def increment_dat_version(self):
pass


class EmptyDataMixin(abc.ABC):
"""A mixin for :class:`Dat` and :class:`Global` objects that takes
Expand Down Expand Up @@ -75,6 +78,27 @@ def _is_allocated(self):


class VecAccessMixin(abc.ABC):

def __init__(self, petsc_counter=None):
if petsc_counter:
# Use lambda since `_vec` allocates the data buffer
# -> Dat/Global should not allocate storage until accessed
self._dat_version = lambda: self._vec.stateGet()
self.increment_dat_version = lambda: self._vec.stateIncrease()
else:
# No associated PETSc Vec if incompatible type:
# -> Equip Dat/Global with their own counter.
self._version = 0
self._dat_version = lambda: self._version

def _inc():
self._version += 1
self.increment_dat_version = _inc

@property
def dat_version(self):
return self._dat_version()

@abc.abstractmethod
def vec_context(self, access):
pass
Expand Down
13 changes: 7 additions & 6 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "global_#x%x" % id(self)
self.comm = comm
# Object versioning setup
petsc_counter = (self.comm and self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

@utils.cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -104,6 +107,7 @@ def shape(self):
@property
def data(self):
"""Data array."""
self.increment_dat_version()
if len(self._data) == 0:
raise RuntimeError("Illegal access: No data associated with this Global!")
return self._data
Expand All @@ -115,12 +119,13 @@ def dtype(self):
@property
def data_ro(self):
"""Data array."""
view = self.data.view()
view = self._data.view()
view.setflags(write=False)
return view

@data.setter
def data(self, value):
self.increment_dat_version()
self._data[:] = utils.verify_reshape(value, self.dtype, self.dim)

@property
Expand Down Expand Up @@ -153,6 +158,7 @@ def copy(self, other, subset=None):
@mpi.collective
def zero(self, subset=None):
assert subset is None
self.increment_dat_version()
self._data[...] = 0

@mpi.collective
Expand Down Expand Up @@ -282,11 +288,6 @@ def vec_context(self, access):
"""A context manager for a :class:`PETSc.Vec` from a :class:`Global`.
:param access: Access descriptor: READ, WRITE, or RW."""
# PETSc Vecs have a state counter and cache norm computations
# to return immediately if the state counter is unchanged.
# Since we've updated the data behind their back, we need to
# change that state counter.
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
data = self._data
Expand Down
10 changes: 10 additions & 0 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,12 @@ def __iter__(self):
"""Iterate over all :class:`Mat` blocks by row and then by column."""
yield from itertools.chain(*self.blocks)

@property
def dat_version(self):
if self.assembly_state != Mat.ASSEMBLED:
raise RuntimeError("Should not ask for state counter if the matrix is not assembled.")
return self.handle.stateGet()

@mpi.collective
def zero(self):
"""Zero the matrix."""
Expand Down Expand Up @@ -936,6 +942,10 @@ def __init__(self, parent, i, j):
self.comm = parent.comm
self.local_to_global_maps = self.handle.getLGMap()

@property
def dat_version(self):
return self.handle.stateGet()

@utils.cached_property
def _kernel_args_(self):
return (self.handle.handle, )
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,8 +832,8 @@ def test_dat_zero_cdim(self, set):
dset = set**0
d = op2.Dat(dset)
assert d.shape == (set.total_size, 0)
assert d.data.size == 0
assert d.data.shape == (set.total_size, 0)
assert d._data.size == 0
assert d._data.shape == (set.total_size, 0)


class TestMixedDatAPI:
Expand Down
86 changes: 86 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,92 @@ def test_dat_save_and_load(self, tmpdir, d1, s, mdat):
mdat2.load(output)
assert all(all(d.data_ro == d_.data_ro) for d, d_ in zip(mdat, mdat2))

def test_dat_version(self, s, d1):
"""Check object versioning for Dat"""
d2 = op2.Dat(s)

assert d1.dat_version == 0
assert d2.dat_version == 0

# Access data property
d1.data

assert d1.dat_version == 1
assert d2.dat_version == 0

# Access data property
d2.data[:] += 1

assert d1.dat_version == 1
assert d2.dat_version == 1

# Access zero property
d1.zero()

assert d1.dat_version == 2
assert d2.dat_version == 1

# Copy d2 into d1
d2.copy(d1)

assert d1.dat_version == 3
assert d2.dat_version == 1

# Context managers (without changing d1 and d2)
with d1.vec_wo as _:
pass

with d2.vec as _:
pass

# Dat version shouldn't change as we are just calling the context manager
# and not changing the Dat objects.
assert d1.dat_version == 3
assert d2.dat_version == 1

# Context managers (modify d1 and d2)
with d1.vec_wo as x:
x += 1

with d2.vec as x:
x += 1

assert d1.dat_version == 4
assert d2.dat_version == 2

def test_mixed_dat_version(self, s, d1, mdat):
"""Check object versioning for MixedDat"""
d2 = op2.Dat(s)
mdat2 = op2.MixedDat([d1, d2])

assert mdat.dat_version == 0
assert mdat2.dat_version == 0

# Access data property
mdat2.data

# mdat2.data will call d1.data and d2.data
assert d1.dat_version == 1
assert d2.dat_version == 1
assert mdat.dat_version == 2
assert mdat2.dat_version == 2

# Access zero property
mdat.zero()

# mdat.zero() will call d1.zero() twice
assert d1.dat_version == 3
assert d2.dat_version == 1
assert mdat.dat_version == 6
assert mdat2.dat_version == 4

# Access zero property
d1.zero()

assert d1.dat_version == 4
assert mdat.dat_version == 8
assert mdat2.dat_version == 5


if __name__ == '__main__':
import os
Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,35 @@ def test_global_operations():
assert (g1 * g2).data == 10.
g1 *= g2
assert g1.data == 10.


def test_global_dat_version():
g1 = op2.Global(1, data=1.)
g2 = op2.Global(1, data=2.)

assert g1.dat_version == 0
assert g2.dat_version == 0

# Access data property
d1 = g1.data

assert g1.dat_version == 1
assert g2.dat_version == 0

# Access data property
g2.data[:] += 1

assert g1.dat_version == 1
assert g2.dat_version == 1

# Access zero property
g1.zero()

assert g1.dat_version == 2
assert g2.dat_version == 1

# Access data setter
g2.data = d1

assert g1.dat_version == 2
assert g2.dat_version == 2

0 comments on commit 784b217

Please sign in to comment.