diff --git a/devito/ir/iet/utils.py b/devito/ir/iet/utils.py index 6908668fc6..99662ce1da 100644 --- a/devito/ir/iet/utils.py +++ b/devito/ir/iet/utils.py @@ -3,7 +3,8 @@ from devito.tools import filter_ordered from devito.types import Global -__all__ = ['filter_iterations', 'retrieve_iteration_tree', 'derive_parameters'] +__all__ = ['filter_iterations', 'retrieve_iteration_tree', 'derive_parameters', + 'maybe_alias'] class IterationTree(tuple): @@ -122,3 +123,34 @@ def derive_parameters(iet, drop_locals=False): parameters = [p for p in parameters if not (p.is_ArrayBasic or p.is_LocalObject)] return parameters + + +def maybe_alias(obj, candidate): + """ + True if `candidate` can act as an alias for `obj`, False otherwise. + """ + if obj is candidate: + return True + + # Names are unique throughout compilation, so this is another case we can handle + # straightforwardly. It might happen that we have an alias used in a subroutine + # with different type qualifiers (e.g., const vs not const, volatile vs not + # volatile), but if the names match, they definitely represent the same + # logical object + if obj.name == candidate.name: + return True + + if obj.is_AbstractFunction: + if not candidate.is_AbstractFunction: + # Obv + return False + + # E.g. TimeFunction vs SparseFunction -> False + if type(obj).__base__ is not type(candidate).__base__: + return False + + # TODO: At some point we may need to introduce some logic here, but we'll + # also need to introduce something like __eq_weak__ that compares most of + # the __rkwargs__ except for e.g. 
the name + + return False diff --git a/devito/passes/iet/asynchrony.py b/devito/passes/iet/asynchrony.py index 93bb66a81d..7e1ff6e36b 100644 --- a/devito/passes/iet/asynchrony.py +++ b/devito/passes/iet/asynchrony.py @@ -1,16 +1,18 @@ from collections import OrderedDict +from ctypes import c_int import cgen as c from devito.ir import (AsyncCall, AsyncCallable, BlankLine, Call, Callable, Conditional, Dereference, DummyExpr, FindNodes, FindSymbols, Iteration, List, PointerCast, Return, ThreadCallable, - Transformer, While) + Transformer, While, maybe_alias) from devito.passes.iet.engine import iet_pass from devito.symbolics import (CondEq, CondNe, FieldFromComposite, FieldFromPointer, Null) from devito.tools import DefaultOrderedDict, Bunch, split -from devito.types import Lock, Pointer, PThreadArray, QueueID, SharedData, Symbol +from devito.types import (Lock, Pointer, PThreadArray, QueueID, SharedData, Symbol, + VolatileInt) __all__ = ['pthreadify'] @@ -48,6 +50,9 @@ def lower_async_callables(iet, track=None, root=None, sregistry=None): defines = FindSymbols('defines').visit(root.body) ncfields, cfields = split(fields, lambda i: i in defines) + # Postprocess `ncfields` + ncfields = sanitize_ncfields(ncfields) + # SharedData -- that is the data structure that will be used by the # main thread to pass information down to the child thread(s) sdata = track[iet.name].sdata = SharedData(name='sdata', @@ -135,7 +140,7 @@ def lower_async_calls(iet, track=None, sregistry=None): d = threads.index arguments = [] for a in n.arguments: - if a in sdata.ncfields: + if any(maybe_alias(a, i) for i in sdata.ncfields): continue elif isinstance(a, QueueID): # Different pthreads use different queues @@ -208,3 +213,20 @@ def lower_async_calls(iet, track=None, sregistry=None): assert not finalization return iet, {'efuncs': tuple(efuncs.values())} + + +# *** Utils + +def sanitize_ncfields(ncfields): + # Due to a bug in the NVC compiler (v<=22.7 and potentially later), + # we have to use 
C's `volatile` more extensively than strictly necessary + # to avoid flaky optimizations that would cause faulty behaviour in rare, + # non-deterministic scenarios + sanitized = [] + for i in ncfields: + if i._C_ctype is c_int: + sanitized.append(VolatileInt(name=i.name)) + else: + sanitized.append(i) + + return sanitized diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index de5df60639..8426873260 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -182,7 +182,7 @@ def test_tasking_fused(self): exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root) b = 13 if configuration['language'] == 'openacc' else 12 # No `qid` w/ OMP assert str(exprs[b]) == 'const int deviceid = sdata->deviceid;' - assert str(exprs[b+1]) == 'int time = sdata->time;' + assert str(exprs[b+1]) == 'volatile int time = sdata->time;' assert str(exprs[b+2]) == 'lock0[0] = 1;' assert exprs[b+3].write is u assert exprs[b+4].write is v