Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Object versioning for DataCarrier object #579

Merged
merged 29 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0ff6b73
Object versioning for DataCarrier object (finality: firedrake.Constant)
nbouziani May 4, 2020
c945975
Update object versioning + add test
nbouziani May 6, 2020
f44f7eb
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani May 9, 2020
43aca1e
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani Jul 22, 2020
6e0178b
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani Sep 7, 2020
fe63bf7
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani Dec 20, 2020
e5af975
Update AUTHORS
nbouziani Dec 20, 2020
ff738b2
Merge branch 'master' into DataCarrier-object-versionning
nbouziani Apr 26, 2021
516454d
Update data versioning for Dat and Mat (base and petsc_base classes)
nbouziani Apr 29, 2021
07a47b3
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani May 24, 2021
29973b1
Add some missing DataCarrier inits
nbouziani May 26, 2021
7d3d4be
Merge master
nbouziani Sep 24, 2021
6d39e69
Merge remote-tracking branch 'origin/master' into DataCarrier-object-…
nbouziani Nov 14, 2021
520078b
Add mat.py
nbouziani Nov 14, 2021
d89e93c
Refactor object versioning using PETSc state counter
nbouziani Nov 16, 2021
ca5e50d
Fix typo
nbouziani Nov 16, 2021
54dd188
Add tests for Dat
nbouziani Nov 16, 2021
18f9fa4
Remove spurious incrementations from VecAccessMixin
nbouziani Nov 16, 2021
2836c7a
Cleanup
nbouziani Nov 21, 2021
7bcde5a
Add dat_version to MixedDat + test
nbouziani Nov 21, 2021
0e879bd
Specify petsc branch in ci.yml
nbouziani Nov 24, 2021
55ad119
Update few things
nbouziani Nov 25, 2021
82f9ac5
Update test + remove @abstractmethod from increment_dat_version
nbouziani Nov 25, 2021
3edaf9a
Merge branch 'DataCarrier-object-versionning' of github.com:OP2/PyOP2…
nbouziani Nov 25, 2021
ba8d12c
Merge master
nbouziani Jul 13, 2022
0989f1d
Fix lint
nbouziani Jul 13, 2022
a1d9d9f
Update ci.yml
nbouziani Jul 14, 2022
9eabf60
Add tests for Dat context managers
nbouziani Jul 21, 2022
ff58332
Merge branch 'DataCarrier-object-versionning' of github.com:OP2/PyOP2…
nbouziani Jul 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Florian Rathgeber <florian.rathgeber@gmail.com>
Francis Russell <francis@unchartedbackwaters.co.uk>
Kaho Sato <kahosato93@gmail.com>
Reuben W. Nixon-Hill <reuben.nixon-hill10@imperial.ac.uk>
Nacime Bouziani <n.bouziani18@imperial.ac.uk>
20 changes: 14 additions & 6 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def data(self):
:meth:`data_with_halos`.

"""
# Increment dat_version since this accessor assumes data modification
self.increment_dat_version()
if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0:
raise RuntimeError("Illegal access: no data associated with this Dat!")
self.halo_valid = False
Expand Down Expand Up @@ -255,6 +257,8 @@ def zero(self, subset=None):
"""Zero the data associated with this :class:`Dat`

:arg subset: A :class:`Subset` of entries to zero (optional)."""
# Data modification
self.increment_dat_version()
# If there is no subset we can safely zero the halo values.
if subset is None:
self._data[:] = 0
Expand Down Expand Up @@ -668,6 +672,12 @@ def data_ro_with_halos(self):

class Dat(AbstractDat, VecAccessMixin):

def __init__(self, *args, **kwargs):
AbstractDat.__init__(self, *args, **kwargs)
# Determine if we can rely on PETSc state counter
petsc_counter = (self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

@utils.cached_property
def _vec(self):
assert self.dtype == PETSc.ScalarType, \
Expand All @@ -685,11 +695,6 @@ def vec_context(self, access):
r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.

:param access: Access descriptor: READ, WRITE, or RW."""
# PETSc Vecs have a state counter and cache norm computations
# to return immediately if the state counter is unchanged.
# Since we've updated the data behind their back, we need to
# change that state counter.
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
self.halo_valid = False
Expand Down Expand Up @@ -729,6 +734,10 @@ def what(x):
# TODO: Think about different communicators on dats (c.f. MixedSet)
self.comm = self._dats[0].comm

@property
def dat_version(self):
return sum(d.dat_version for d in self._dats)

def __call__(self, access, path=None):
from pyop2.parloop import MixedDatLegacyArg
return MixedDatLegacyArg(self, path, access)
Expand Down Expand Up @@ -1012,7 +1021,6 @@ def vec_context(self, access):
size = v.local_size
array[offset:offset+size] = v.array_r[:]
offset += size
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
# Reverse scatter to get the values back to their original locations
Expand Down
24 changes: 24 additions & 0 deletions pyop2/types/data_carrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def cdim(self):
the product of the dim tuple."""
return self._cdim

def increment_dat_version(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

declare this abstract.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that this requires all the subclasses to implement it and at the moment Global, Dat and MixedDat get the implementation of increment_dat_version via another route, i.e. VecAccessMixin. That’s because we don’t want to duplicate the object versioning setup in Global and AbstractDat objects so I put it in the appropriate common ancestor (VecAccessMixin).

pass


class EmptyDataMixin(abc.ABC):
"""A mixin for :class:`Dat` and :class:`Global` objects that takes
Expand Down Expand Up @@ -75,6 +78,27 @@ def _is_allocated(self):


class VecAccessMixin(abc.ABC):

def __init__(self, petsc_counter=None):
if petsc_counter:
# Use lambda since `_vec` allocates the data buffer
# -> Dat/Global should not allocate storage until accessed
self._dat_version = lambda: self._vec.stateGet()
self.increment_dat_version = lambda: self._vec.stateIncrease()
else:
# No associated PETSc Vec if incompatible type:
# -> Equip Dat/Global with their own counter.
self._version = 0
self._dat_version = lambda: self._version

def _inc():
self._version += 1
self.increment_dat_version = _inc

@property
def dat_version(self):
return self._dat_version()

@abc.abstractmethod
def vec_context(self, access):
pass
Expand Down
13 changes: 7 additions & 6 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "global_#x%x" % id(self)
self.comm = comm
# Object versioning setup
petsc_counter = (self.comm and self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

@utils.cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -104,6 +107,7 @@ def shape(self):
@property
def data(self):
"""Data array."""
self.increment_dat_version()
if len(self._data) == 0:
raise RuntimeError("Illegal access: No data associated with this Global!")
return self._data
Expand All @@ -115,12 +119,13 @@ def dtype(self):
@property
def data_ro(self):
"""Data array."""
view = self.data.view()
view = self._data.view()
view.setflags(write=False)
return view

@data.setter
def data(self, value):
self.increment_dat_version()
self._data[:] = utils.verify_reshape(value, self.dtype, self.dim)

@property
Expand Down Expand Up @@ -153,6 +158,7 @@ def copy(self, other, subset=None):
@mpi.collective
def zero(self, subset=None):
assert subset is None
self.increment_dat_version()
self._data[...] = 0

@mpi.collective
Expand Down Expand Up @@ -282,11 +288,6 @@ def vec_context(self, access):
"""A context manager for a :class:`PETSc.Vec` from a :class:`Global`.

:param access: Access descriptor: READ, WRITE, or RW."""
# PETSc Vecs have a state counter and cache norm computations
# to return immediately if the state counter is unchanged.
# Since we've updated the data behind their back, we need to
# change that state counter.
self._vec.stateIncrease()
yield self._vec
if access is not Access.READ:
data = self._data
Expand Down
10 changes: 10 additions & 0 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,12 @@ def __iter__(self):
"""Iterate over all :class:`Mat` blocks by row and then by column."""
yield from itertools.chain(*self.blocks)

@property
def dat_version(self):
if self.assembly_state != Mat.ASSEMBLED:
raise RuntimeError("Should not ask for state counter if the matrix is not assembled.")
return self.handle.stateGet()
nbouziani marked this conversation as resolved.
Show resolved Hide resolved

@mpi.collective
def zero(self):
"""Zero the matrix."""
Expand Down Expand Up @@ -936,6 +942,10 @@ def __init__(self, parent, i, j):
self.comm = parent.comm
self.local_to_global_maps = self.handle.getLGMap()

@property
def dat_version(self):
return self.handle.stateGet()

@utils.cached_property
def _kernel_args_(self):
return (self.handle.handle, )
Expand Down
4 changes: 2 additions & 2 deletions test/unit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,8 +832,8 @@ def test_dat_zero_cdim(self, set):
dset = set**0
d = op2.Dat(dset)
assert d.shape == (set.total_size, 0)
assert d.data.size == 0
assert d.data.shape == (set.total_size, 0)
assert d._data.size == 0
assert d._data.shape == (set.total_size, 0)


class TestMixedDatAPI:
Expand Down
86 changes: 86 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,92 @@ def test_dat_save_and_load(self, tmpdir, d1, s, mdat):
mdat2.load(output)
assert all(all(d.data_ro == d_.data_ro) for d, d_ in zip(mdat, mdat2))

def test_dat_version(self, s, d1):
"""Check object versioning for Dat"""
d2 = op2.Dat(s)

assert d1.dat_version == 0
assert d2.dat_version == 0

# Access data property
d1.data

assert d1.dat_version == 1
assert d2.dat_version == 0

# Access data property
d2.data[:] += 1

assert d1.dat_version == 1
assert d2.dat_version == 1

# Access zero property
d1.zero()

assert d1.dat_version == 2
assert d2.dat_version == 1

# Copy d2 into d1
d2.copy(d1)

assert d1.dat_version == 3
assert d2.dat_version == 1

# Context managers (without changing d1 and d2)
with d1.vec_wo as _:
pass

with d2.vec as _:
pass

# Dat version shouldn't change as we are just calling the context manager
# and not changing the Dat objects.
assert d1.dat_version == 3
assert d2.dat_version == 1

# Context managers (modify d1 and d2)
with d1.vec_wo as x:
x += 1

with d2.vec as x:
x += 1

assert d1.dat_version == 4
assert d2.dat_version == 2

def test_mixed_dat_version(self, s, d1, mdat):
"""Check object versioning for MixedDat"""
d2 = op2.Dat(s)
mdat2 = op2.MixedDat([d1, d2])

assert mdat.dat_version == 0
assert mdat2.dat_version == 0

# Access data property
mdat2.data

# mdat2.data will call d1.data and d2.data
assert d1.dat_version == 1
assert d2.dat_version == 1
assert mdat.dat_version == 2
assert mdat2.dat_version == 2

# Access zero property
mdat.zero()

# mdat.zero() will call d1.zero() twice
assert d1.dat_version == 3
assert d2.dat_version == 1
assert mdat.dat_version == 6
assert mdat2.dat_version == 4

# Access zero property
d1.zero()

assert d1.dat_version == 4
assert mdat.dat_version == 8
assert mdat2.dat_version == 5


if __name__ == '__main__':
import os
Expand Down
32 changes: 32 additions & 0 deletions test/unit/test_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,35 @@ def test_global_operations():
assert (g1 * g2).data == 10.
g1 *= g2
assert g1.data == 10.


def test_global_dat_version():
g1 = op2.Global(1, data=1.)
g2 = op2.Global(1, data=2.)

assert g1.dat_version == 0
assert g2.dat_version == 0

# Access data property
d1 = g1.data

assert g1.dat_version == 1
assert g2.dat_version == 0

# Access data property
g2.data[:] += 1

assert g1.dat_version == 1
assert g2.dat_version == 1

# Access zero property
g1.zero()

assert g1.dat_version == 2
assert g2.dat_version == 1

# Access data setter
g2.data = d1

assert g1.dat_version == 2
assert g2.dat_version == 2