Skip to content

Commit

Permalink
compiler: Abstract semantically identical compounds
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Nov 9, 2023
1 parent 51cae2d commit 210a874
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 157 deletions.
20 changes: 8 additions & 12 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'make_callable',
'EntryFunction', 'AsyncCallable', 'AsyncCall', 'ThreadCallable',
'DeviceFunction', 'DeviceCall', 'KernelLaunch']
'DeviceFunction', 'DeviceCall', 'KernelLaunch', 'CommCallable']


# ElementalFunction machinery
Expand Down Expand Up @@ -105,17 +105,7 @@ def make_callable(name, iet, retval='void', prefix='static'):
"""
Utility function to create a Callable from an IET.
"""
parameters = derive_parameters(iet)

# TODO: this should be done by `derive_parameters`, and perhaps better, e.g.
# ordering such that TimeFunctions go first, then Functions, etc. However,
# doing it would require updating a *massive* number of tests and notebooks,
# hence for now we limit it here
# NOTE: doing it not just for code aesthetics, but also so that semantically
# identical callables can be abstracted homogeneously irrespective of the
# object names, which dictate the ordering in the callable signature
parameters = sorted(parameters, key=lambda p: str(type(p)))

parameters = derive_parameters(iet, ordering='canonical')
return Callable(name, iet, retval, parameters=parameters, prefix=prefix)


Expand Down Expand Up @@ -221,3 +211,9 @@ def functions(self):
if self.stream is not None:
launch_args += (self.stream.function,)
return super().functions + launch_args


# Other relevant Callable subclasses

class CommCallable(Callable):
pass
2 changes: 1 addition & 1 deletion devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def all_parameters(self):
@property
def functions(self):
return tuple(i.function for i in self.all_parameters
if isinstance(i.function, AbstractFunction))
if isinstance(i.function, (AbstractFunction, AbstractObject)))

@property
def defines(self):
Expand Down
13 changes: 12 additions & 1 deletion devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def filter_iterations(tree, key=lambda i: i):
return filtered


def derive_parameters(iet, drop_locals=False):
def derive_parameters(iet, drop_locals=False, ordering='default'):
"""
Derive all input parameters (function call arguments) from an IET
by collecting all symbols not defined in the tree itself.
"""
assert ordering in ('default', 'canonical')

# Extract all candidate parameters
candidates = FindSymbols().visit(iet)

Expand All @@ -122,6 +124,15 @@ def derive_parameters(iet, drop_locals=False):
if drop_locals:
parameters = [p for p in parameters if not (p.is_ArrayBasic or p.is_LocalObject)]

# NOTE: This is requested by the caller when the parameters are used to
# construct Callables whose signature only depends on the object types,
# rather than on their name
# TODO: It should maybe be done systematically... but it's gonna change a huge
# amount of tests and examples; plus, it might break compatibility those
# using devito as a library-generator to be embedded within legacy codes
if ordering == 'canonical':
parameters = sorted(parameters, key=lambda p: str(type(p)))

return parameters


Expand Down
29 changes: 22 additions & 7 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from devito.ir.support.space import Backward
from devito.symbolics import ListInitializer, ccode, uxreplace
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype, c_restrict_void_p)
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)
Expand Down Expand Up @@ -224,7 +225,7 @@ def _gen_struct_decl(self, obj, masked=()):

def _gen_value(self, obj, level=2, masked=()):
qualifiers = [v for k, v in self._qualifiers_mapper.items()
if getattr(obj, k, False) and v not in masked]
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and level == 2:
strtype = obj._C_typedata
Expand All @@ -233,7 +234,8 @@ def _gen_value(self, obj, level=2, masked=()):
strtype = ctypes_to_cstr(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and level >= 1:
strtype = '%s%s' % (strtype, self._restrict_keyword)
if not obj._mem_stack:
strtype = '%s%s' % (strtype, self._restrict_keyword)
strtype = ' '.join(qualifiers + [strtype])

strname = obj._C_name
Expand Down Expand Up @@ -632,10 +634,10 @@ def visit_Operator(self, o, mode='all'):
# Elemental functions
esigns = []
efuncs = [blankline]
for i in o._func_table.values():
if i.local:
esigns.append(self._gen_signature(i.root))
efuncs.extend([self._visit(i.root), blankline])
items = [i.root for i in o._func_table.values() if i.local]
for i in sorted_efuncs(items):
esigns.append(self._gen_signature(i))
efuncs.extend([self._visit(i), blankline])

# Definitions
headers = [c.Define(*i) for i in o._headers] + [blankline]
Expand Down Expand Up @@ -1279,3 +1281,16 @@ def generate(self):
if self.cast:
tip = '(%s)%s' % (self.cast, tip)
yield tip


def sorted_efuncs(efuncs):
from devito.ir.iet.efunc import (CommCallable, DeviceFunction,
ThreadCallable, ElementalFunction)

priority = {
DeviceFunction: 3,
ThreadCallable: 2,
ElementalFunction: 1,
CommCallable: 1
}
return sorted_priority(efuncs, priority)
4 changes: 2 additions & 2 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet import (Call, Callable, Conditional, ElementalFunction,
Expression, ExpressionBundle, AugmentedExpression,
Iteration, List, Prodder, Return, make_efunc, FindNodes,
Transformer, ElementalCall)
Transformer, ElementalCall, CommCallable)
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def _call_poke(self, poke):
# Callable sub-hierarchy


class MPICallable(Callable):
class MPICallable(CommCallable):

def __init__(self, name, body, parameters):
super(MPICallable, self).__init__(name, body, 'void', parameters, ('static',))
Expand Down
84 changes: 53 additions & 31 deletions devito/passes/iet/asynchrony.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import cgen as c

from devito.ir import (AsyncCall, AsyncCallable, BlankLine, Call, Callable,
Conditional, Dereference, DummyExpr, FindNodes, FindSymbols,
Conditional, DummyExpr, FindNodes, FindSymbols,
Iteration, List, PointerCast, Return, ThreadCallable,
Transformer, While, maybe_alias)
Transformer, While, make_callable, maybe_alias)
from devito.passes.iet.definitions import DataManager
from devito.passes.iet.engine import iet_pass
from devito.symbolics import (CondEq, CondNe, FieldFromComposite, FieldFromPointer,
Expand Down Expand Up @@ -60,26 +60,26 @@ def lower_async_callables(iet, root=None, sregistry=None):
ncfields=ncfields,
pname=sregistry.make_name(prefix='tsdata')
)
sbase = sdata.symbolic_base
sbase = sdata.indexed

# Prepend the SharedData fields available upon thread activation
preactions = [DummyExpr(i, FieldFromPointer(i.name, sbase)) for i in ncfields]
preactions = [DummyExpr(i, FieldFromPointer(i.base, sbase)) for i in ncfields]
preactions.append(BlankLine)

# Append the flag reset
postactions = [List(body=[
BlankLine,
DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)
DummyExpr(FieldFromPointer(sdata.symbolic_flag, sbase), 1)
])]

wrap = List(body=preactions + list(iet.body.body) + postactions)

# The thread has work to do when it receives the signal that all locks have
# been set to 0 by the main thread
wrap = Conditional(CondEq(FieldFromPointer(sdata._field_flag, sbase), 2), wrap)
wrap = Conditional(CondEq(FieldFromPointer(sdata.symbolic_flag, sbase), 2), wrap)

# The thread keeps spinning until the alive flag is set to 0 by the main thread
wrap = While(CondNe(FieldFromPointer(sdata._field_flag, sbase), 0), wrap)
wrap = While(CondNe(FieldFromPointer(sdata.symbolic_flag, sbase), 0), wrap)

# pthread functions expect exactly one argument of type void*
tparameter = Pointer(name='_%s' % sdata.name)
Expand All @@ -88,9 +88,11 @@ def lower_async_callables(iet, root=None, sregistry=None):
unpacks = [PointerCast(sdata, tparameter), BlankLine]
for i in cfields:
if i.is_AbstractFunction:
unpacks.append(Dereference(i, sdata))
unpacks.append(
DummyExpr(i._C_symbol, FieldFromPointer(i._C_symbol, sbase))
)
else:
unpacks.append(DummyExpr(i, FieldFromPointer(i.name, sbase)))
unpacks.append(DummyExpr(i, FieldFromPointer(i.base, sbase)))

body = iet.body._rebuild(body=[wrap, Return(Null)], unpacks=unpacks)
iet = ThreadCallable(iet.name, body, tparameter)
Expand All @@ -112,11 +114,20 @@ def lower_async_calls(iet, track=None, sregistry=None):

assert n.name in track
sdata = track[n.name]
sbase = sdata.symbolic_base
sbase = sdata.indexed
name = sregistry.make_name(prefix='init_%s' % sdata.name)
body = [DummyExpr(FieldFromPointer(i._C_name, sbase), i._C_symbol)
for i in sdata.cfields]
body.extend([BlankLine, DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1)])
body = []
for i in sdata.cfields:
if i.is_AbstractFunction:
body.append(
DummyExpr(FieldFromPointer(i._C_symbol, sbase), i._C_symbol)
)
else:
body.append(DummyExpr(FieldFromPointer(i.base, sbase), i.base))
body.extend([
BlankLine,
DummyExpr(FieldFromPointer(sdata.symbolic_flag, sbase), 1)
])
parameters = sdata.cfields + (sdata,)
efuncs[n.name] = Callable(name, body, 'void', parameters, 'static')

Expand All @@ -135,7 +146,7 @@ def lower_async_calls(iet, track=None, sregistry=None):
threads = PThreadArray(name=name, npthreads=sdata.npthreads)

# Call to `sdata` initialization Callable
sbase = sdata.symbolic_base
sbase = sdata.indexed
d = threads.index
arguments = []
for a in n.arguments:
Expand All @@ -152,7 +163,7 @@ def lower_async_calls(iet, track=None, sregistry=None):
call0 = Call(efuncs[n.name].name, arguments)

# Create pthreads
tbase = threads.symbolic_base
tbase = threads.indexed
call1 = Call('pthread_create', (
tbase + d, Null, Call(n.name, [], is_indirect=True), sbase + d
))
Expand All @@ -164,33 +175,34 @@ def lower_async_calls(iet, track=None, sregistry=None):
else:
callback = lambda body: Iteration(body, d, threads.size - 1)
initialization.append(List(
header=c.Comment("Fire up and initialize `%s`" % threads.name),
body=callback([call0, call1])
))

# Finalization
finalization.append(List(
header=c.Comment("Wait for completion of `%s`" % threads.name),
body=callback([
While(CondEq(FieldFromComposite(sdata._field_flag, sdata[d]), 2)),
DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 0),
Call('pthread_join', (threads[d], Null))
])
))
name = sregistry.make_name(prefix='shutdown')
body = List(body=callback([
While(CondEq(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 2)),
DummyExpr(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 0),
Call('pthread_join', (threads[d], Null))
]))
efunc = efuncs[name] = make_callable(name, body)
finalization.append(Call(name, efunc.parameters))

# Activation
if threads.size == 1:
d = threads.index
condition = CondNe(FieldFromComposite(sdata._field_flag, sdata[d]), 1)
condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1)
activation = [While(condition)]
else:
d = Symbol(name=sregistry.make_name(prefix=threads.index.name))
condition = CondNe(FieldFromComposite(sdata._field_flag, sdata[d]), 1)
condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1)
activation = [DummyExpr(d, 0),
While(condition, DummyExpr(d, (d + 1) % threads.size))]
activation.extend([DummyExpr(FieldFromComposite(i.name, sdata[d]), i)
activation.extend([DummyExpr(FieldFromComposite(i.base, sdata[d]), i)
for i in sdata.ncfields])
activation.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 2))
activation.append(
DummyExpr(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 2)
)
activation = List(
header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
body=activation,
Expand All @@ -203,9 +215,19 @@ def lower_async_calls(iet, track=None, sregistry=None):
iet = Transformer(mapper).visit(iet)

# Inject initialization and finalization
initialization.append(BlankLine)
finalization.insert(0, BlankLine)
body = iet.body._rebuild(body=initialization + list(iet.body.body) + finalization)
initialization = List(
header=c.Comment("Fire up and initialize pthreads"),
body=initialization + [BlankLine]
)

finalization = List(
header=c.Comment("Wait for completion of pthreads"),
body=finalization
)

body = iet.body._rebuild(
body=[initialization] + list(iet.body.body) + [BlankLine, finalization]
)
iet = iet._rebuild(body=body)
else:
assert not initialization
Expand Down
Loading

0 comments on commit 210a874

Please sign in to comment.