Skip to content

Commit

Permalink
dsl: rework sparse subfunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 24, 2024
1 parent 67ce156 commit abdee0b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 52 deletions.
10 changes: 2 additions & 8 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
timed_region, contains_val)
from devito.types import Grid, Evaluable, SubFunction
from devito.types import Grid, Evaluable

__all__ = ['Operator']

Expand Down Expand Up @@ -659,14 +659,8 @@ def _postprocess_errors(self, retval):

def _postprocess_arguments(self, args, **kwargs):
"""Process runtime arguments upon returning from ``.apply()``."""
pnames = {p.name for p in self.parameters}
for p in self.parameters:
try:
subfuncs = (args[getattr(p, s).name] for s in p._sub_functions)
p._arg_apply(args[p.name], *subfuncs, alias=kwargs.get(p.name))
except AttributeError:
if not (isinstance(p, SubFunction) and p.parent in self.parameters):
p._arg_apply(args[p.name], alias=kwargs.get(p.name))
p._arg_apply(args[p.name], alias=kwargs.get(p.name))

@cached_property
def _known_arguments(self):
Expand Down
40 changes: 24 additions & 16 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
_default_radius = {'linear': 1, 'sinc': 4}


class SparseSubFunction(SubFunction):

def _arg_apply(self, dataobj, **kwargs):
if self.parent is not None:
return self.parent._dist_subfunc_gather(dataobj, self)
return super()._arg_apply(dataobj, **kwargs)


class AbstractSparseFunction(DiscreteFunction):

"""
Expand Down Expand Up @@ -58,8 +66,13 @@ def __init_finalize__(self, *args, **kwargs):
@classmethod
def __indices_setup__(cls, *args, **kwargs):
dimensions = as_tuple(kwargs.get('dimensions'))
# Need this not to break MatrixSparseFunction
try:
_sub_funcs = tuple(cls._sub_functions)
except TypeError:
_sub_funcs = ()
# If a subfunction provided use the sparse dimension
for f in cls._sub_functions:
for f in _sub_funcs:
if f in kwargs:
try:
sparse_dim = kwargs[f].indices[0]
Expand Down Expand Up @@ -128,18 +141,19 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
key = np.array(key)

# Check if already a SubFunction
d = self._sparse_dim
d = self.indices[self._sparse_position]
if isinstance(key, SubFunction):
if d in key.dimensions:
if d in key.indices:
# Can use as is, dimension already matches
return key
if self.alias:
return key._rebuild(alias=self.alias, name=name)
else:
return key
else:
# Need to rebuild so the dimensions match the parent
# SparseFunction, for example we end up here via `.subs(d, new_d)`
print("rebuilding")
indices = (d, *key.indices[1:])
return key._rebuild(*indices, name=name, shape=shape,
alias=self.alias, halo=None)
return key._rebuild(*indices, name=name, alias=self.alias)

# Given an array or nothing, create dimension and SubFunction
if key is not None:
Expand Down Expand Up @@ -170,7 +184,7 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
else:
dtype = dtype or self.dtype

sf = SubFunction(
sf = SparseSubFunction(
name=name, dtype=dtype, dimensions=dimensions,
shape=shape, space_order=0, initializer=key, alias=self.alias,
distributor=self._distributor, parent=self
Expand Down Expand Up @@ -597,12 +611,6 @@ def _dist_scatter(self, data=None):
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
return mapper

def _dist_gather(self, data, *subfunc):
self._dist_data_gather(data)
for (sg, s) in zip(subfunc, self._sub_functions):
if getattr(self, s) is not None:
self._dist_subfunc_gather(sg, getattr(self, s))

def _eval_at(self, func):
return self

Expand Down Expand Up @@ -650,11 +658,11 @@ def _arg_values(self, **kwargs):

return values

def _arg_apply(self, dataobj, *subfuncs, alias=None):
def _arg_apply(self, dataobj, alias=None):
key = alias if alias is not None else self
if isinstance(key, AbstractSparseFunction):
# Gather into `self.data`
key._dist_gather(dataobj, *subfuncs)
key._dist_data_gather(dataobj)
elif self._distributor.nprocs > 1:
raise NotImplementedError("Don't know how to gather data from an "
"object of type `%s`" % type(key))
Expand Down
9 changes: 5 additions & 4 deletions examples/seismic/inversion/fwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from devito import configuration, Function, norm, mmax, mmin

from examples.seismic import demo_model, AcquisitionGeometry, Receiver
from examples.seismic import demo_model, AcquisitionGeometry
from examples.seismic.acoustic import AcousticWaveSolver

from inversion_utils import compute_residual, update_with_box
Expand Down Expand Up @@ -57,11 +57,12 @@

# Create placeholders for the data residual and data
residual = geometry.new_rec(name='residual')
d_obs = geometry.new_rec(name='d_obs')
d_syn = geometry.new_rec(name='d_syn')
d_obs = geometry.new_rec(name='d_obs', coordinates=residual.coordinates)
d_syn = geometry.new_rec(name='d_syn', coordinates=residual.coordinates)

src = solver.geometry.src


def fwi_gradient(vp_in):
# Create symbols to hold the gradient
grad = Function(name="grad", grid=model.grid)
Expand All @@ -74,7 +75,7 @@ def fwi_gradient(vp_in):
solver.forward(vp=model.vp, rec=d_obs, src=src)

# Compute smooth data and full forward wavefield u0
_, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn)
_, u0, _ = solver.forward(vp=vp_in, save=True, rec=d_syn, src=src)

# Compute gradient from data residual and update objective function
compute_residual(residual, d_obs, d_syn)
Expand Down
6 changes: 5 additions & 1 deletion examples/seismic/test_seismic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,9 @@ def test_geom(shape):

src1 = geometry.src
src2 = geometry.src
assert src1.coordinates is src2.coordinates
assert src1.coordinates is not src2.coordinates
assert src1._sparse_dim is src2._sparse_dim

src3 = geometry.new_src(name="src3", coordinates=src1.coordinates)
assert src1.coordinates is src3.coordinates
assert src1._sparse_dim is src3._sparse_dim
26 changes: 5 additions & 21 deletions examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,37 +159,24 @@ def r(self):
def interpolation(self):
return self._interpolation

@property
def src_coords(self):
return self._src_coordinates

@property
def rec_coords(self):
return self._rec_coordinates

@property
def rec(self):
return self.new_rec()
self._rec_coordinates = rec.coordinates
return rec

def new_rec(self, name='rec'):
coords = self.rec_coords or self.rec_positions
def new_rec(self, name='rec', coordinates=None):
coords = coordinates or self.rec_positions
rec = Receiver(name=name, grid=self.grid,
time_range=self.time_axis, npoint=self.nrec,
interpolation=self.interpolation, r=self._r,
coordinates=coords)

if self.rec_coords is None:
self._rec_coordinates = rec.coordinates

return rec

@property
def adj_src(self):
if self.src_type is None:
return self.new_rec()
coords = self.rec_coords or self.rec_positions
coords = self.rec_positions
adj_src = sources[self.src_type](name='rec', grid=self.grid, f0=self.f0,
time_range=self.time_axis, npoint=self.nrec,
interpolation=self.interpolation, r=self._r,
Expand All @@ -203,8 +190,8 @@ def adj_src(self):
def src(self):
return self.new_src()

def new_src(self, name='src', src_type='self'):
coords = self.src_coords or self.src_positions
def new_src(self, name='src', src_type='self', coordinates=None):
coords = coordinates or self.src_positions
if self.src_type is None or src_type is None:
warning("No source type defined, returning uninitiallized (zero) source")
src = PointSource(name=name, grid=self.grid,
Expand All @@ -218,9 +205,6 @@ def new_src(self, name='src', src_type='self'):
t0=self._t0w, a=self._a,
interpolation=self.interpolation, r=self._r)

if self.src_coords is None:
self._src_coordinates = src.coordinates

return src


Expand Down
3 changes: 1 addition & 2 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,7 @@ def test_rebuild(self, sptype):
# Check new subfunction
for subf in sp2._sub_functions:
if getattr(sp2, subf) is not None:
assert getattr(sp2, subf).name.startswith("sr_")
assert np.all(getattr(sp2, subf).data == 0)
assert getattr(sp2, subf) == getattr(sp, subf)

# Rebuild with different name as an alias
sp2 = sp._rebuild(name="sr2", alias=True)
Expand Down

0 comments on commit abdee0b

Please sign in to comment.