Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functions on subdomains (WIP) #1380

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Placeholder.
  • Loading branch information
rhodrin committed Sep 15, 2020
commit b00a5238dade4531b5dd8d9275a2db9ac30013d3
15 changes: 9 additions & 6 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,22 @@ def lower_exprs(expressions, **kwargs):

processed = []
for expr in as_tuple(expressions):
# Update access maps for `Function`'s defined on a `SubDomain`
fosd = [f for f in retrieve_functions(expr, mode='unique')
if f.is_Function and f._subdomain]
expr = expr.subs({f: f.subs(f._subdomain._access_map) for f in fosd})
try:
dimension_map = expr.subdomain.dimension_map
except AttributeError:
# Some Relationals may be pure SymPy objects, thus lacking the subdomain
dimension_map = {}

# Gather `Function`'s defined on a `SubDomain`
fosd = set([f for f in retrieve_functions(expr, mode='unique')
if f.is_Function and f._subdomain])

# Handle Functions (typical case)
mapper = {f: f.indexify(lshift=True, subs=dimension_map)
for f in retrieve_functions(expr)}
mapper = {**{f: f.indexify(lshift=True, subs=dimension_map)
for f in set(retrieve_functions(expr)).difference(fosd)},
**{f: f.indexify(lshift=True, subs=f._subdomain._access_map)
for f in fosd}}
from IPython import embed; embed()

# Handle Indexeds (from index notation)
for i in retrieve_indexed(expr):
Expand Down
138 changes: 135 additions & 3 deletions tests/test_subdomains.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from devito import (Grid, Function, TimeFunction, Eq, solve, Operator, SubDomain,
SubDomainSet, Dimension)
from devito.tools import timed_region
from examples.seismic import TimeAxis, RickerSource, Receiver


class TestSubdomains(object):
Expand Down Expand Up @@ -360,12 +361,12 @@ class Inner(SubDomainSet):

class TestSubdomainFunctions(object):
"""
Class for testing `Function`'s defined on `SubDomain`'s
Class for testing `Function`'s defined on `SubDomain`'s without MPI.
"""

def test_basic_function(self):
"""
Fill me in.
Test a single `Function`
"""

class Middle(SubDomain):
Expand All @@ -388,10 +389,47 @@ def define(self, dimensions):

assert(np.all(f.data[:] == 1))

def test_mixed_functions(self):
"""
Test with one Function on a `SubDomain` and one not.
"""

class Middle(SubDomain):

name = 'middle'

def define(self, dimensions):
x, y = dimensions
return {x: ('middle', 2, 2), y: ('middle', 3, 1)}

mid = Middle()

grid = Grid(shape=(10, 10), extent=(9., 9.), subdomains=(mid, ))
f = Function(name='f', grid=grid, subdomain=grid.subdomains['middle'])
g = Function(name='g', grid=grid)

assert(f.shape == grid.subdomains['middle'].shape)
assert(g.shape == grid.shape)

eq0 = Eq(f, g+f+1, subdomain=grid.subdomains['middle'])
eq1 = Eq(g, 2*f, subdomain=grid.subdomains['middle'])
eq2 = Eq(f, g+1, subdomain=grid.subdomains['middle'])

Operator([eq0, eq1, eq2])()

assert(np.all(f.data[:] == 3))
assert(np.all(g.data[2:-2, 3:-1] == 2))


class TestSubdomainFunctionsParallel(object):
"""
Class for testing `Function`'s defined on `SubDomain`'s with MPI.
"""

@pytest.mark.parallel(mode=4)
def test_mpi_function(self):
"""
Fill me in.
Test a single `Function`
"""

class Middle(SubDomain):
Expand All @@ -413,3 +451,97 @@ def define(self, dimensions):
Operator(eq)()

assert(np.all(f.data[:] == 1))

@pytest.mark.parallel(mode=4)
def test_mixed_functions_mpi(self):
"""
Test with one Function on a `SubDomain` and one not.
"""

class Middle(SubDomain):

name = 'middle'

def define(self, dimensions):
x, y = dimensions
return {x: ('middle', 2, 2), y: ('middle', 3, 1)}

mid = Middle()

grid = Grid(shape=(10, 10), extent=(9., 9.), subdomains=(mid, ))
f = Function(name='f', grid=grid, subdomain=grid.subdomains['middle'])
g = Function(name='g', grid=grid)

assert(f.shape == grid.subdomains['middle'].shape_local)
assert(g.shape == grid.shape_local)

eq0 = Eq(f, g+f+1, subdomain=grid.subdomains['middle'])
eq1 = Eq(g, 2*f, subdomain=grid.subdomains['middle'])
eq2 = Eq(f, g+1, subdomain=grid.subdomains['middle'])

Operator([eq0, eq1, eq2])()

assert(np.all(f.data[:] == 3))
assert(np.all(g.data[2:-2, 3:-1] == 2))

@pytest.mark.parallel(mode=4)
def test_acoustic_on_sd(self):

class CompDom(SubDomain):

name = 'comp_domain'

def define(self, dimensions):
x, y = dimensions
return {x: ('middle', 20, 10), y: ('middle', 20, 10)}

cdomain = CompDom()

shape = (131, 131)
extent = (1300, 1300)
origin = (200., 200.)

v = np.empty(shape, dtype=np.float32)
v[:, :71] = 1.5
v[:, 71:] = 2.5

grid = Grid(shape=shape, extent=extent, origin=origin, subdomains=(cdomain, ))

t0 = 0.
tn = 1000.
dt = 1.6
time_range = TimeAxis(start=t0, stop=tn, step=dt)

f0 = 0.010
src = RickerSource(name='src', grid=grid, f0=f0,
npoint=1, time_range=time_range)

domain_size = np.array(extent)

src.coordinates.data[0, :] = domain_size*.5
src.coordinates.data[0, -1] = 20.

rec = Receiver(name='rec', grid=grid, npoint=101, time_range=time_range)
rec.coordinates.data[:, 0] = np.linspace(0, domain_size[0], num=101)
rec.coordinates.data[:, 1] = 20.

u = TimeFunction(name="u", grid=grid, time_order=2, space_order=2,
subdomain=grid.subdomains['comp_domain'])
m = Function(name='m', grid=grid)
m.data[:] = 1./(v*v)

pde = m * u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward),
subdomain=grid.subdomains['comp_domain'])

src_term = src.inject(field=u.forward, expr=src * dt**2 / m)
rec_term = rec.interpolate(expr=u.forward)

op = Operator([stencil] + src_term + rec_term)

# Make sure we've indeed generated OpenMP offloading code
assert 'omp target' in str(op)

op(time=time_range.num-1, dt=dt)

assert np.isclose(norm(rec), 490.55, atol=1e-2, rtol=0)