Skip to content

Commit

Permalink
compiler: Tweak stability check
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Mar 12, 2024
1 parent 8cad55c commit aaea94a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
28 changes: 13 additions & 15 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import cgen as c
import numpy as np
from sympy import Not

from devito.finite_differences import Abs
from devito.finite_differences.differentiable import Pow
from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.tools import dtype_to_cstr
from devito.types import Eq, Inc, Symbol

__all__ = ['check_stability', 'error_mapper']
Expand Down Expand Up @@ -53,27 +51,27 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
else:
continue

name = sregistry.make_name(prefix='energy')
energy = Symbol(name=name, dtype=f.dtype)

eqns = [Eq(energy, 0.0),
Inc(energy, Abs(Pow(f.subs(f.time_dim, 0), 2)))]
accumulator = Symbol(name='accumulator', dtype=f.dtype)
eqns = [Eq(accumulator, 0.0),
Inc(accumulator, f.subs(f.time_dim, 0))]
irs, byproduct = rcompile(eqns)
body = irs.iet.body.body + (Return(energy),)

name = sregistry.make_name(prefix='compute_energy')
retval = dtype_to_cstr(energy.dtype)
efunc = make_callable(name, body, retval=retval)
name = sregistry.make_name(prefix='is_finite')
retval = Return(DefFunction('isfinite', accumulator))
body = irs.iet.body.body + (retval,)
efunc = make_callable(name, body, retval='int')

efuncs.extend([i.root for i in byproduct.funcs])
efuncs.append(efunc)

includes.extend(byproduct.includes)

name = sregistry.make_name(prefix='check')
check = Symbol(name=name, dtype=np.int32)

errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
Call(efunc.name, efunc.parameters, retobj=energy),
Conditional(Not(DefFunction('isfinite', energy)),
Return(error_mapper['Stability']))
Call(efunc.name, efunc.parameters, retobj=check),
Conditional(Not(check), Return(error_mapper['Stability']))
]))
errctl = List(header=c.Comment("Stability check"), body=[errctl])
mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))
Expand Down
12 changes: 9 additions & 3 deletions tests/test_error_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,23 @@


@switchconfig(safe_math=True)
def test_stability():
@pytest.mark.parametrize("expr", [
'u/f',
'(u + v)/f',
])
def test_stability(expr):
grid = Grid(shape=(10, 10))

f = Function(name='f', grid=grid, space_order=2)
f = Function(name='f', grid=grid, space_order=2) # noqa
u = TimeFunction(name='u', grid=grid, space_order=2)
v = TimeFunction(name='v', grid=grid, space_order=2)

eq = Eq(u.forward, u/f)
eq = Eq(u.forward, eval(expr))

op = Operator(eq, opt=('advanced', {'errctl': 'max'}))

u.data[:] = 1.
v.data[:] = 2.

with pytest.raises(ExecutionError):
op.apply(time_M=200, dt=.1)

0 comments on commit aaea94a

Please sign in to comment.