Skip to content

Commit

Permalink
algorithms: Update Fns on SubDomain's indexing here + bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhodrin committed Jul 29, 2020
1 parent 36a4619 commit 980e8c4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
4 changes: 4 additions & 0 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def lower_exprs(expressions, **kwargs):

processed = []
for expr in as_tuple(expressions):
# Update access maps for `Function`'s defined on a `SubDomain`
fosd = [f for f in retrieve_functions(expr, mode='unique')
if f.is_Function and f._subdomain]
expr = expr.subs({f: f.subs(f._subdomain._access_map) for f in fosd})
try:
dimension_map = expr.subdomain.dimension_map
except AttributeError:
Expand Down
13 changes: 1 addition & 12 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable):
# its key routines (e.g., solve)
_iterable = False

is_Function = False
is_Input = True
is_DiscreteFunction = True
is_Tensor = True
Expand Down Expand Up @@ -970,18 +971,6 @@ def __init_finalize__(self, *args, **kwargs):
# parameter has to be computed at x + hx/2)
self._is_parameter = kwargs.get('parameter', False)

# TODO: Review/tidy new properties
@cached_property
def on_subdomain(self):
return bool(self._subdomain)

@cached_property
def _domain(self):
""" Shortcut to the computational domain on which the function
is defined """
# TODO: Add sanity check here
return self._subdomain if self._subdomain else self.grid

@cached_property
def _fd_priority(self):
return 1 if self.staggered in [NODE, None] else 2
Expand Down
16 changes: 11 additions & 5 deletions devito/types/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,12 @@ def __subdomain_finalize__(self, dimensions, shape, distributor=None, **kwargs):
# Derive the local shape for `SubDomain`'s on distributed grids along with the
# memory access map for any `Function` defined on this `SubDomain`.
access_map = {}
shift = {}
shape_local = []
for dim, d, s in zip(sub_dimensions, distributor.decomposition, self._shape):
if dim.is_Sub:
c_name = 'c_%s' % dim.name
shift[c_name] = Constant(name=c_name, dtype=np.int32)
if dim._local:
if distributor and distributor.is_parallel:
if dim.thickness.right[1] == 0:
Expand All @@ -422,12 +425,14 @@ def __subdomain_finalize__(self, dimensions, shape, distributor=None, **kwargs):
if r is None:
r = 0
shape_local.append(ls-l-r)
access_map.update({dim: dim-l if l else dim})
shift[c_name].data = l
access_map.update({dim: dim-shift[c_name]})
else:
if dim.thickness.left[1] == 0:
access_map.update({dim: dim-(s-dim.thickness.right[1])})
shift[c_name].data = (s-dim.thickness.right[1])
else:
access_map.update({dim: dim})
shift[c_name].data = 0
access_map.update({dim: dim-shift[c_name]})
shape_local.append(s)
else:
if distributor and distributor.is_parallel:
Expand All @@ -448,10 +453,11 @@ def __subdomain_finalize__(self, dimensions, shape, distributor=None, **kwargs):
if r is None:
r = 0
shape_local.append(ls-l-r)
access_map.update({dim: dim-l if l else dim})
shift[c_name].data = l
else:
access_map.update({dim: dim-dim.thickness.left[1]})
shift[c_name].data = dim.thickness.left[1]
shape_local.append(s)
access_map.update({dim: dim-shift[c_name]})
else:
shape_local.append(len(d.loc_abs_numb))
access_map.update({dim: dim})
Expand Down
2 changes: 1 addition & 1 deletion tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def test_solve(self, operate_on_empty_cache):
# to u(t + dt), u(x - h_x) and u(x + h_x) that have to be cleared.
# Then `u` points to the various Dimensions, the Dimensions point to the various
# spacing symbols, hence, we need four sweeps to clear up the cache.
assert len(_SymbolCache) == 16
assert len(_SymbolCache) == 17
clear_cache()
assert len(_SymbolCache) == 9
clear_cache()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_subdomains.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def define(self, dimensions):

grid = Grid(shape=(10, 10), extent=(9., 9.), subdomains=(mid, ))
f = Function(name='f', grid=grid, subdomain=grid.subdomains['middle'])
eq = Eq(f, f+1).subs(f._domain._access_map)
eq = Eq(f, f+1)

assert(f.shape == grid.subdomains['middle'].shape)

Expand All @@ -406,7 +406,7 @@ def define(self, dimensions):

grid = Grid(shape=(10, 10), extent=(9., 9.), subdomains=(mid, ))
f = Function(name='f', grid=grid, subdomain=grid.subdomains['middle'])
eq = Eq(f, f+1).subs(f._domain._access_map)
eq = Eq(f, f+1)

assert(f.shape == grid.subdomains['middle'].shape_local)

Expand Down

0 comments on commit 980e8c4

Please sign in to comment.