Skip to content

Commit

Permalink
compiler: Add optional pass to check stability
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Mar 11, 2024
1 parent 5df7b3d commit d4abfcd
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 8 deletions.
9 changes: 7 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize,
mpiize, hoist_prodders, relax_incr_dimensions,
check_stability)
from devito.tools import timed_pass

__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down Expand Up @@ -76,6 +77,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

# Recognised but unused by the CPU backend
oo.pop('par-disabled', None)
Expand Down Expand Up @@ -189,6 +191,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
9 changes: 7 additions & 2 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering,
cire, cse, factorize, fission, fuse,
optimize_pows)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, hoist_prodders,
linearize, pthreadify, relax_incr_dimensions)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
hoist_prodders, linearize, pthreadify,
relax_incr_dimensions, check_stability)
from devito.tools import as_tuple, timed_pass

__all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator',
Expand Down Expand Up @@ -91,6 +92,7 @@ def _normalize_kwargs(cls, **kwargs):
o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
o['place-transfers'] = oo.pop('place-transfers', True)
o['errctl'] = oo.pop('errctl', cls.ERRCTL)

if oo:
raise InvalidOperator("Unsupported optimization options: [%s]"
Expand Down Expand Up @@ -226,6 +228,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Misc optimizations
hoist_prodders(graph)

# Perform error checking
check_stability(graph, **kwargs)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

Expand Down
10 changes: 10 additions & 0 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ class BasicOperator(Operator):
(default) or `int32`.
"""

ERRCTL = 'basic'
"""
Runtime error checking. If this option is enabled, the generated code will
include runtime checks for various things that might go south, such as
instability (e.g., NaNs), failed library calls (e.g., kernel launches).
"""

_Target = None
"""
The target language constructor, to be specified by subclasses.
Expand Down Expand Up @@ -155,6 +162,9 @@ def _check_kwargs(cls, **kwargs):
if oo['deriv-unroll'] not in (False, 'inner', 'full'):
raise InvalidArgument("Illegal `deriv-unroll` value")

if oo['errctl'] not in (None, False, 'basic', 'max'):
raise InvalidArgument("Illegal `errctl` value")

def _autotune(self, args, setup):
if setup in [False, 'off']:
return args
Expand Down
78 changes: 74 additions & 4 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,86 @@
import cgen as c
from sympy import Not

from devito.finite_differences import Abs
from devito.finite_differences.differentiable import Pow
from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.tools import dtype_to_cstr
from devito.types import Eq, Inc, Symbol

__all__ = ['check_stability', 'error_mapper']


@iet_pass
def check_stability(iet, **kwargs):
def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs):
"""
Check if the simulation is stable. If not, return to Python as quickly as
possible with an error code.
"""
# TODO
return iet, {}
if options['errctl'] != 'max':
return

_, wmovs = graph.data_movs

_check_stability(graph, wmovs=wmovs, rcompile=rcompile, sregistry=sregistry)


@iet_pass
def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
if not isinstance(iet, EntryFunction):
return iet, {}

# NOTE: Stability is a domain-specific concept, hence looking for time
# Iterations and TimeFunctions is acceptable
efuncs = []
includes = []
mapper = {}
for n in FindNodes(Iteration).visit(iet):
if not n.dim.is_Time:
continue

functions = [f for f in FindSymbols().visit(n)
if f.is_TimeFunction and f.time_dim.is_Stepping]

# We compute the norm of just one TimeFunction, hence we sort for
# determinism and reproducibility
candidates = sorted(set(functions) & set(wmovs), key=lambda f: f.name)
for f in candidates:
if f in wmovs:
break
else:
continue

name = sregistry.make_name(prefix='energy')
energy = Symbol(name=name, dtype=f.dtype)

eqns = [Eq(energy, 0.0),
Inc(energy, Abs(Pow(f.subs(f.time_dim, 0), 2)))]
irs, byproduct = rcompile(eqns)
body = irs.iet.body.body + (Return(energy),)

name = sregistry.make_name(prefix='compute_energy')
retval = dtype_to_cstr(energy.dtype)
efunc = make_callable(name, body, retval=retval)

efuncs.extend([i.root for i in byproduct.funcs])
efuncs.append(efunc)

includes.extend(byproduct.includes)

errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
Call(efunc.name, efunc.parameters, retobj=energy),
Conditional(Not(DefFunction('isfinite', energy)),
Return(error_mapper['Stability']))
]))
errctl = List(header=c.Comment("Stability check"), body=[errctl])
mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))

iet = Transformer(mapper).visit(iet)

return iet, {'efuncs': efuncs, 'includes': includes}


error_mapper = {
Expand Down
21 changes: 21 additions & 0 deletions tests/test_error_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig
from devito.exceptions import ExecutionError


@switchconfig(safe_math=True)
def test_stability():
grid = Grid(shape=(10, 10))

f = Function(name='f', grid=grid, space_order=2)
u = TimeFunction(name='u', grid=grid, space_order=2)

eq = Eq(u.forward, u/f)

op = Operator(eq, opt=('advanced', {'errctl': 'max'}))

u.data[:] = 1.

with pytest.raises(ExecutionError):
op.apply(time_M=200, dt=.1)

0 comments on commit d4abfcd

Please sign in to comment.