From d4abfcd212f080b214dfa227c8512b79f207cdeb Mon Sep 17 00:00:00 2001
From: Fabio Luporini
Date: Thu, 7 Mar 2024 21:50:23 +0000
Subject: [PATCH] compiler: Add optional pass to check stability

---
 devito/core/cpu.py           |  9 +++-
 devito/core/gpu.py           |  9 +++-
 devito/core/operator.py      | 12 +++++
 devito/passes/iet/errors.py  | 86 ++++++++++++++++++++++++++++++++++--
 tests/test_error_checking.py | 22 +++++++++
 5 files changed, 130 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_error_checking.py

diff --git a/devito/core/cpu.py b/devito/core/cpu.py
index 81611da9b49..1d343146974 100644
--- a/devito/core/cpu.py
+++ b/devito/core/cpu.py
@@ -6,8 +6,9 @@
 from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
                                     factorize, fission, fuse, optimize_pows,
                                     optimize_hyperplanes)
-from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
-                               hoist_prodders, relax_incr_dimensions)
+from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize,
+                               mpiize, hoist_prodders, relax_incr_dimensions,
+                               check_stability)
 from devito.tools import timed_pass
 
 __all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
@@ -76,6 +77,7 @@ def _normalize_kwargs(cls, **kwargs):
         o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
         o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
         o['place-transfers'] = oo.pop('place-transfers', True)
+        o['errctl'] = oo.pop('errctl', cls.ERRCTL)
 
         # Recognised but unused by the CPU backend
         oo.pop('par-disabled', None)
@@ -189,6 +191,9 @@ def _specialize_iet(cls, graph, **kwargs):
         # Misc optimizations
         hoist_prodders(graph)
 
+        # Perform error checking
+        check_stability(graph, **kwargs)
+
         # Symbol definitions
         cls._Target.DataManager(**kwargs).process(graph)
 
diff --git a/devito/core/gpu.py b/devito/core/gpu.py
index de070c68900..a0c2da774a4 100644
--- a/devito/core/gpu.py
+++ b/devito/core/gpu.py
@@ -10,8 +10,9 @@
 from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering,
                                     cire, cse, factorize, fission, fuse,
                                     optimize_pows)
-from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize, hoist_prodders,
-                               linearize, pthreadify, relax_incr_dimensions)
+from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
+                               hoist_prodders, linearize, pthreadify,
+                               relax_incr_dimensions, check_stability)
 from devito.tools import as_tuple, timed_pass
 
 __all__ = ['DeviceNoopOperator', 'DeviceAdvOperator', 'DeviceCustomOperator',
@@ -91,6 +92,7 @@ def _normalize_kwargs(cls, **kwargs):
         o['mapify-reduce'] = oo.pop('mapify-reduce', cls.MAPIFY_REDUCE)
         o['index-mode'] = oo.pop('index-mode', cls.INDEX_MODE)
         o['place-transfers'] = oo.pop('place-transfers', True)
+        o['errctl'] = oo.pop('errctl', cls.ERRCTL)
 
         if oo:
             raise InvalidOperator("Unsupported optimization options: [%s]"
@@ -226,6 +228,9 @@ def _specialize_iet(cls, graph, **kwargs):
         # Misc optimizations
         hoist_prodders(graph)
 
+        # Perform error checking
+        check_stability(graph, **kwargs)
+
         # Symbol definitions
         cls._Target.DataManager(**kwargs).process(graph)
 
diff --git a/devito/core/operator.py b/devito/core/operator.py
index 1ba976d3b68..ff3630a83c5 100644
--- a/devito/core/operator.py
+++ b/devito/core/operator.py
@@ -119,6 +119,15 @@ class BasicOperator(Operator):
     (default) or `int32`.
     """
 
+    ERRCTL = 'basic'
+    """
+    Runtime error checking. If this option is enabled, the generated code will
+    include runtime checks for various things that might go south, such as
+    instability (e.g., NaNs) or failed library calls (e.g., kernel launches).
+    Legal values are `None`, `False`, `'basic'` (default), and `'max'`; the
+    stability check is only generated with `'max'`.
+    """
+
     _Target = None
     """
     The target language constructor, to be specified by subclasses.
@@ -155,6 +164,9 @@ def _check_kwargs(cls, **kwargs):
         if oo['deriv-unroll'] not in (False, 'inner', 'full'):
             raise InvalidArgument("Illegal `deriv-unroll` value")
 
+        if oo['errctl'] not in (None, False, 'basic', 'max'):
+            raise InvalidArgument("Illegal `errctl` value")
+
     def _autotune(self, args, setup):
         if setup in [False, 'off']:
             return args
diff --git a/devito/passes/iet/errors.py b/devito/passes/iet/errors.py
index 440b92762c5..a3810c72eb3 100644
--- a/devito/passes/iet/errors.py
+++ b/devito/passes/iet/errors.py
@@ -1,16 +1,94 @@
+import cgen as c
+from sympy import Not
+
+from devito.finite_differences import Abs
+from devito.finite_differences.differentiable import Pow
+from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
+                           Return, FindNodes, FindSymbols, Transformer,
+                           make_callable)
 from devito.passes.iet.engine import iet_pass
+from devito.symbolics import CondEq, DefFunction
+from devito.tools import dtype_to_cstr
+from devito.types import Eq, Inc, Symbol
 
 __all__ = ['check_stability', 'error_mapper']
 
 
-@iet_pass
-def check_stability(iet, **kwargs):
+def check_stability(graph, options=None, rcompile=None, sregistry=None, **kwargs):
     """
     Check if the simulation is stable. If not, return to Python as quickly as
     possible with an error code.
     """
-    # TODO
-    return iet, {}
+    # The check is opt-in: do nothing unless `errctl` is explicitly 'max'
+    if not options or options.get('errctl') != 'max':
+        return
+
+    _, wmovs = graph.data_movs
+
+    _check_stability(graph, wmovs=wmovs, rcompile=rcompile, sregistry=sregistry)
+
+
+@iet_pass
+def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
+    """
+    Instrument the time Iterations of the EntryFunction with a stability check.
+
+    For each time Iteration, one TimeFunction in `wmovs` is selected, and a
+    callable accumulating `Abs(f**2)` into an `energy` Symbol is generated. The
+    callable is invoked every 100 time steps; if the returned value is not
+    finite, the generated code returns `error_mapper['Stability']`.
+    """
+    if not isinstance(iet, EntryFunction):
+        return iet, {}
+
+    # NOTE: Stability is a domain-specific concept, hence looking for time
+    # Iterations and TimeFunctions is acceptable
+    efuncs = []
+    includes = []
+    mapper = {}
+    for n in FindNodes(Iteration).visit(iet):
+        if not n.dim.is_Time:
+            continue
+
+        functions = [f for f in FindSymbols().visit(n)
+                     if f.is_TimeFunction and f.time_dim.is_Stepping]
+
+        # We compute the norm of just one TimeFunction, hence we sort for
+        # determinism and reproducibility
+        candidates = sorted(set(functions) & set(wmovs), key=lambda f: f.name)
+        if not candidates:
+            continue
+        f = candidates[0]
+
+        name = sregistry.make_name(prefix='energy')
+        energy = Symbol(name=name, dtype=f.dtype)
+
+        eqns = [Eq(energy, 0.0),
+                Inc(energy, Abs(Pow(f.subs(f.time_dim, 0), 2)))]
+        irs, byproduct = rcompile(eqns)
+        body = irs.iet.body.body + (Return(energy),)
+
+        name = sregistry.make_name(prefix='compute_energy')
+        retval = dtype_to_cstr(energy.dtype)
+        efunc = make_callable(name, body, retval=retval)
+
+        efuncs.extend([i.root for i in byproduct.funcs])
+        efuncs.append(efunc)
+
+        includes.extend(byproduct.includes)
+
+        # Perform the check every 100 time steps
+        errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
+            Call(efunc.name, efunc.parameters, retobj=energy),
+            Conditional(Not(DefFunction('isfinite', energy)),
+                        Return(error_mapper['Stability']))
+        ]))
+        errctl = List(header=c.Comment("Stability check"), body=[errctl])
+        mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))
+
+    iet = Transformer(mapper).visit(iet)
+
+    return iet, {'efuncs': efuncs, 'includes': includes}
 
 
 error_mapper = {
diff --git a/tests/test_error_checking.py b/tests/test_error_checking.py
new file mode 100644
index 00000000000..5a3929f10cc
--- /dev/null
+++ b/tests/test_error_checking.py
@@ -0,0 +1,22 @@
+import pytest
+
+from devito import Grid, Function, TimeFunction, Eq, Operator, switchconfig
+from devito.exceptions import ExecutionError
+
+
+@switchconfig(safe_math=True)
+def test_stability():
+    """An unstable run with `errctl='max'` must raise an ExecutionError."""
+    grid = Grid(shape=(10, 10))
+
+    f = Function(name='f', grid=grid, space_order=2)
+    u = TimeFunction(name='u', grid=grid, space_order=2)
+
+    eq = Eq(u.forward, u/f)
+
+    op = Operator(eq, opt=('advanced', {'errctl': 'max'}))
+
+    u.data[:] = 1.
+
+    with pytest.raises(ExecutionError):
+        op.apply(time_M=200, dt=.1)