From abdee0b396ec61cefa611f3dd4cdb1001204d57c Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 May 2024 13:16:57 -0400 Subject: [PATCH] dsl: rework sparse subfunctions --- devito/operator/operator.py | 10 ++----- devito/types/sparse.py | 40 +++++++++++++++----------- examples/seismic/inversion/fwi.py | 9 +++--- examples/seismic/test_seismic_utils.py | 6 +++- examples/seismic/utils.py | 26 ++++------------- tests/test_sparse.py | 3 +- 6 files changed, 42 insertions(+), 52 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 9c4f20031e0..c9e84ac1c05 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -27,7 +27,7 @@ from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, timed_region, contains_val) -from devito.types import Grid, Evaluable, SubFunction +from devito.types import Grid, Evaluable __all__ = ['Operator'] @@ -659,14 +659,8 @@ def _postprocess_errors(self, retval): def _postprocess_arguments(self, args, **kwargs): """Process runtime arguments upon returning from ``.apply()``.""" - pnames = {p.name for p in self.parameters} for p in self.parameters: - try: - subfuncs = (args[getattr(p, s).name] for s in p._sub_functions) - p._arg_apply(args[p.name], *subfuncs, alias=kwargs.get(p.name)) - except AttributeError: - if not (isinstance(p, SubFunction) and p.parent in self.parameters): - p._arg_apply(args[p.name], alias=kwargs.get(p.name)) + p._arg_apply(args[p.name], alias=kwargs.get(p.name)) @cached_property def _known_arguments(self): diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 3ffe03a4ac7..abd60bdf734 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -29,6 +29,14 @@ _default_radius = {'linear': 1, 'sinc': 4} +class SparseSubFunction(SubFunction): + + def _arg_apply(self, dataobj, **kwargs): + if self.parent is not None: + return self.parent._dist_subfunc_gather(dataobj, self) + return super()._arg_apply(dataobj, **kwargs) + + class AbstractSparseFunction(DiscreteFunction): """ @@ -58,8 +66,13 @@ def __init_finalize__(self, *args, **kwargs): @classmethod def __indices_setup__(cls, *args, **kwargs): dimensions = as_tuple(kwargs.get('dimensions')) + # Need this not to break MatrixSparseFunction + try: + _sub_funcs = tuple(cls._sub_functions) + except TypeError: + _sub_funcs = () # If a subfunction provided use the sparse dimension - for f in cls._sub_functions: + for f in _sub_funcs: if f in kwargs: try: sparse_dim = kwargs[f].indices[0] @@ -128,18 +141,19 @@ def __subfunc_setup__(self, key, suffix, dtype=None): key = np.array(key) # Check if already a SubFunction - d = self._sparse_dim + d = self.indices[self._sparse_position] if isinstance(key, SubFunction): - if d in key.dimensions: + if d in key.indices: # Can use as is, dimension already matches - return key + if self.alias: + return key._rebuild(alias=self.alias, name=name) + else: + return key else: # Need to rebuild so the dimensions match the parent # SparseFunction, for example we end up here via `.subs(d, new_d)` - print("rebuilding") indices = (d, *key.indices[1:]) - return key._rebuild(*indices, name=name, shape=shape, - alias=self.alias, halo=None) + return key._rebuild(*indices, name=name, alias=self.alias) # Given an array or nothing, create dimension and SubFunction if key is not None: @@ -170,7 +184,7 @@ def __subfunc_setup__(self, key, suffix, dtype=None): else: dtype = dtype or self.dtype - sf = SubFunction( + sf = SparseSubFunction( name=name, dtype=dtype, dimensions=dimensions, shape=shape, space_order=0, initializer=key, alias=self.alias, distributor=self._distributor, parent=self @@ -597,12 +611,6 @@ def _dist_scatter(self, data=None): mapper.update(self._dist_subfunc_scatter(getattr(self, i))) return mapper - def _dist_gather(self, data, *subfunc): - self._dist_data_gather(data) - for (sg, s) in zip(subfunc, self._sub_functions): - if getattr(self, s) is not None: - self._dist_subfunc_gather(sg, getattr(self, s)) - def _eval_at(self, func): return self @@ -650,11 +658,11 @@ def _arg_values(self, **kwargs): return values - def _arg_apply(self, dataobj, *subfuncs, alias=None): + def _arg_apply(self, dataobj, alias=None): key = alias if alias is not None else self if isinstance(key, AbstractSparseFunction): # Gather into `self.data` - key._dist_gather(dataobj, *subfuncs) + key._dist_data_gather(dataobj) elif self._distributor.nprocs > 1: raise NotImplementedError("Don't know how to gather data from an " "object of type `%s`" % type(key)) diff --git a/examples/seismic/inversion/fwi.py b/examples/seismic/inversion/fwi.py index 0298aac1309..831b981763f 100644 --- a/examples/seismic/inversion/fwi.py +++ b/examples/seismic/inversion/fwi.py @@ -2,7 +2,7 @@ from devito import configuration, Function, norm, mmax, mmin -from examples.seismic import demo_model, AcquisitionGeometry, Receiver +from examples.seismic import demo_model, AcquisitionGeometry from examples.seismic.acoustic import AcousticWaveSolver from inversion_utils import compute_residual, update_with_box @@ -57,11 +57,12 @@ # Create placeholders for the data residual and data residual = geometry.new_rec(name='residual') -d_obs = geometry.new_rec(name='d_obs') -d_syn = geometry.new_rec(name='d_syn') +d_obs = geometry.new_rec(name='d_obs', coordinates=residual.coordinates) +d_syn = geometry.new_rec(name='d_syn', coordinates=residual.coordinates) src = solver.geometry.src + def fwi_gradient(vp_in): # Create symbols to hold the gradient grad = Function(name="grad", grid=model.grid) @@ -74,7 +75,7 @@ def fwi_gradient(vp_in): solver.forward(vp=model.vp, rec=d_obs, src=src) # Compute smooth data and full forward wavefield u0 - _, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn) + _, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn, src=src) # Compute gradient from data residual and update objective function compute_residual(residual, d_obs, d_syn) diff --git a/examples/seismic/test_seismic_utils.py b/examples/seismic/test_seismic_utils.py index aea8721a6bf..f02ae48cb94 100644 --- a/examples/seismic/test_seismic_utils.py +++ b/examples/seismic/test_seismic_utils.py @@ -97,5 +97,9 @@ def test_geom(shape): src1 = geometry.src src2 = geometry.src - assert src1.coordinates is src2.coordinates + assert src1.coordinates is not src2.coordinates assert src1._sparse_dim is src2._sparse_dim + + src3 = geometry.new_src(name="src3", coordinates=src1.coordinates) + assert src1.coordinates is src3.coordinates + assert src1._sparse_dim is src3._sparse_dim diff --git a/examples/seismic/utils.py b/examples/seismic/utils.py index 9acb4baed36..6491d9ca5ae 100644 --- a/examples/seismic/utils.py +++ b/examples/seismic/utils.py @@ -159,37 +159,24 @@ def r(self): def interpolation(self): return self._interpolation - @property - def src_coords(self): - return self._src_coordinates - - @property - def rec_coords(self): - return self._rec_coordinates - @property def rec(self): return self.new_rec() - self._rec_coordinates = rec.coordinates - return rec - def new_rec(self, name='rec'): - coords = self.rec_coords or self.rec_positions + def new_rec(self, name='rec', coordinates=None): + coords = coordinates or self.rec_positions rec = Receiver(name=name, grid=self.grid, time_range=self.time_axis, npoint=self.nrec, interpolation=self.interpolation, r=self._r, coordinates=coords) - if self.rec_coords is None: - self._rec_coordinates = rec.coordinates - return rec @property def adj_src(self): if self.src_type is None: return self.new_rec() - coords = self.rec_coords or self.rec_positions + coords = self.rec_positions adj_src = sources[self.src_type](name='rec', grid=self.grid, f0=self.f0, time_range=self.time_axis, npoint=self.nrec, interpolation=self.interpolation, r=self._r, @@ -203,8 +190,8 @@ def adj_src(self): def src(self): return self.new_src() - def new_src(self, name='src', src_type='self'): - coords = self.src_coords or self.src_positions + def new_src(self, name='src', src_type='self', coordinates=None): + coords = coordinates or self.src_positions if self.src_type is None or src_type is None: warning("No source type defined, returning uninitiallized (zero) source") src = PointSource(name=name, grid=self.grid, @@ -218,9 +205,6 @@ def new_src(self, name='src', src_type='self'): t0=self._t0w, a=self._a, interpolation=self.interpolation, r=self._r) - if self.src_coords is None: - self._src_coordinates = src.coordinates - return src diff --git a/tests/test_sparse.py b/tests/test_sparse.py index 136d1ef9e34..7c94de69cee 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -423,8 +423,7 @@ def test_rebuild(self, sptype): # Check new subfunction for subf in sp2._sub_functions: if getattr(sp2, subf) is not None: - assert getattr(sp2, subf).name.startswith("sr_") - assert np.all(getattr(sp2, subf).data == 0) + assert getattr(sp2, subf) == getattr(sp, subf) # Rebuild with different name as an alias sp2 = sp._rebuild(name="sr2", alias=True)