Skip to content

Commit 2c3b2b1

Browse files
author
Max Strange
committed
WIP: IR
1 parent 995678b commit 2c3b2b1

File tree

5 files changed

+105
-43
lines changed

5 files changed

+105
-43
lines changed

acc/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def wrapper(*args, **kwargs):
6565
meta_data = MetaVars(src=source, stackframe=stackframe, signature=signature, funcs_name=funcname, funcs_module=module)
6666

6767
intermediate_rep = IntermediateRepresentation(meta_data)
68-
for pragma in frontend.parse_pragmas(intermediate_rep.src, *args, **kwargs):
68+
for pragma, linenumber in frontend.parse_pragmas(intermediate_rep.src, *args, **kwargs):
6969
# Side-effect-y: this function modifies intermediate_rep each time
70-
frontend.accumulate_pragma(intermediate_rep, pragma, *args, **kwargs)
70+
frontend.accumulate_pragma(intermediate_rep, pragma, linenumber, *args, **kwargs)
7171

7272
# Pass the intermediate representation into the backend to get the new source code
7373
new_source = back.compile(intermediate_rep)

acc/backend/backend.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def for_loop(code_object, meta_data):
4848
new_region_src = " " * 4 + "return x * x" + os.linesep
4949
execute_signature = "ls"
5050
execute_src = " " * 4 + "p = Pool(5)" + os.linesep
51-
execute_src += " " * 4 + "return p.map(task, " + execute_signature + ")" +\
52-
os.linesep
51+
execute_src += " " * 4 + "return p.map(task, " + execute_signature + ")" + os.linesep
5352
#-------------------------------------------------------------------
5453

5554
new_src = "from multiprocessing import Pool" + os.linesep

acc/frontend/frontend.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ def parse_pragmas(src, *args, **kwargs):
1414
given in `src`.
1515
"""
1616
regexp = re.compile(r"^((\s)*#(\s)*(pragma)(\s)*(acc))")
17-
for line in src.splitlines():
17+
for lineno, line in enumerate(src.splitlines()):
1818
if regexp.match(line):
19-
yield line
19+
yield line, lineno
2020

21-
def accumulate_pragma(intermediate_rep, pragma, *args, **kwargs):
21+
def accumulate_pragma(intermediate_rep, pragma, lineno, *args, **kwargs):
2222
"""
2323
Modifies `intermediate_rep` according to `pragma`.
2424
"""
2525
directive_and_clauses = pragma.partition("acc")[-1].split(' ')
2626
directive_and_clauses = [word for word in directive_and_clauses if word != '']
2727
directive = directive_and_clauses[0]
2828
clause_list = directive_and_clauses[1:]
29-
_accumulate_pragma_helper(directive, clause_list, intermediate_rep, *args, **kwargs)
29+
_accumulate_pragma_helper(directive, clause_list, intermediate_rep, lineno, *args, **kwargs)
3030

31-
def _accumulate_pragma_helper(directive, clause_list, intermediate_rep, *args, **kwargs):
31+
def _accumulate_pragma_helper(directive, clause_list, intermediate_rep, lineno, *args, **kwargs):
3232
"""
3333
Applies the given directive and its associated clause list
3434
to the given intermediate_rep.
@@ -44,7 +44,7 @@ def _accumulate_pragma_helper(directive, clause_list, intermediate_rep, *args, *
4444
elif directive == "host_data":
4545
pass
4646
elif directive == "loop":
47-
loop(clause_list, intermediate_rep, *args, **kwargs)
47+
loop(clause_list, intermediate_rep, lineno, *args, **kwargs)
4848
elif directive == "atomic":
4949
pass
5050
elif directive == "cache":

acc/frontend/loop/loop.py

+31-27
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class LoopNode(IrNode):
1414
"""
1515
Node for the IntermediateRepresentation tree that is used for loop constructs.
1616
"""
17-
def __init__(self):
18-
super().__init__()
17+
def __init__(self, lineno: int):
18+
super().__init__(lineno)
1919
self.collapse = None
2020
self.gang = None
2121
self.worker = None
@@ -28,7 +28,7 @@ def __init__(self):
2828
self.private = None
2929
self.reduction = None
3030

31-
def loop(clauses, intermediate_rep, *args, **kwargs):
31+
def loop(clauses, intermediate_rep, lineno, *args, **kwargs):
3232
"""
3333
From the docs:
3434
The loop construct can describe what type of parallelism to use to
@@ -65,26 +65,31 @@ def loop(clauses, intermediate_rep, *args, **kwargs):
6565
clause must be written such that the loop iteration count is
6666
computable when entering the loop construct.
6767
"""
68-
loop_node = LoopNode()
69-
intermediate_rep.root.add_child(loop_node)
68+
loop_node = LoopNode(lineno)
7069
index = 0
7170
while index != -1:
72-
index = _apply_clause(index, clauses, intermediate_rep)
71+
index = _apply_clause(index, clauses, intermediate_rep, loop_node)
72+
intermediate_rep.add_child(loop_node)
7373

74-
def _apply_clause(index, clause_list, intermediate_rep):
74+
def _apply_clause(index, clause_list, intermediate_rep, loop_node):
7575
"""
7676
Consumes however much of the clause list as necessary to apply the clause
7777
found at index in the clause_list.
7878
79-
@param index: The index into the clause_list of the clause we are
80-
interested in.
79+
@param index: The index into the clause_list of the clause we are
80+
interested in.
8181
82-
@param clause_list: The list of the clauses that this clause is indexed in.
82+
@param clause_list: The list of the clauses that this clause is indexed in.
8383
84-
@return: The new index. If there are no more
85-
clauses after this one is done, index will be -1.
84+
@param intermediate_rep: The intermediate representation, filled with information
85+
about the source code in general, but not yet this node.
86+
87+
@param loop_node: The node whose information we are filling in with the clauses.
88+
89+
@return: The new index. If there are no more
90+
clauses after this one is done, index will be -1.
8691
"""
87-
args = (index, clause_list, intermediate_rep)
92+
args = (index, clause_list, intermediate_rep, loop_node)
8893
clause = clause_list[index]
8994
# TODO: Remove this debug print
9095
print("clause:", clause)
@@ -111,11 +116,10 @@ def _apply_clause(index, clause_list, intermediate_rep):
111116
elif clause.startswith("reduction"):
112117
return _reduction(*args)
113118
else:
114-
raise InvalidClauseError("Clause either not allowed for this " +\
115-
"directive, or else it may be spelled " +\
116-
"incorrectly. Clause given: " + clause)
119+
errmsg = "Clause either not allowed for this directive, or else it may be spelled incorrectly. Clause given: {} at line: {}".format(clause, loop_node.lineno)
120+
raise InvalidClauseError(errmsg)
117121

118-
def _collapse(index, clause_list, intermediate_rep):
122+
def _collapse(index, clause_list, intermediate_rep, loop_node):
119123
"""
120124
The 'collapse' clause is used to specify how many tightly nested loops
121125
are associated with the 'loop' construct. The argument to the 'collapse'
@@ -158,7 +162,7 @@ def _collapse(index, clause_list, intermediate_rep):
158162

159163
return -1
160164

161-
def _gang(index, clause_list, intermediate_rep):
165+
def _gang(index, clause_list, intermediate_rep, loop_node):
162166
"""
163167
When the parent compute construct is a 'parallel' construct, or on an
164168
orphaned 'loop' construct, the 'gang' clause specifies that the
@@ -199,7 +203,7 @@ def _gang(index, clause_list, intermediate_rep):
199203
"""
200204
return -1
201205

202-
def _worker(index, clause_list, intermediate_rep):
206+
def _worker(index, clause_list, intermediate_rep, loop_node):
203207
"""
204208
When the parent compute construct is a 'parallel' construct, or on an
205209
orphaned 'loop' construct, the 'worker' clause specifies that the
@@ -227,7 +231,7 @@ def _worker(index, clause_list, intermediate_rep):
227231
"""
228232
return -1
229233

230-
def _vector(index, clause_list, intermediate_rep):
234+
def _vector(index, clause_list, intermediate_rep, loop_node):
231235
"""
232236
When the parent compute construct is a 'parallel' construct, or on an
233237
orphaned 'loop' construct, the 'vector' clause specifies that the
@@ -256,15 +260,15 @@ def _vector(index, clause_list, intermediate_rep):
256260
"""
257261
return -1
258262

259-
def _seq(index, clause_list, intermediate_rep):
263+
def _seq(index, clause_list, intermediate_rep, loop_node):
260264
"""
261265
The 'seq' clause specifies that the associated loop or loops are to be
262266
executed sequentially by the accelerator. This clause will override any
263267
automatic parallelization or vectorization.
264268
"""
265269
return -1
266270

267-
def _auto(index, clause_list, intermediate_rep):
271+
def _auto(index, clause_list, intermediate_rep, loop_node):
268272
"""
269273
The 'auto' clause specifies that the implementation must analyze the
270274
loop and determine whether to run the loop sequentially. The
@@ -276,7 +280,7 @@ def _auto(index, clause_list, intermediate_rep):
276280
"""
277281
return -1
278282

279-
def _tile(index, clause_list, intermediate_rep):
283+
def _tile(index, clause_list, intermediate_rep, loop_node):
280284
"""
281285
The 'tile' clause specifies that the implementation should split each
282286
loop nest into two loops, with an outer set of tile loops and an
@@ -303,14 +307,14 @@ def _tile(index, clause_list, intermediate_rep):
303307
"""
304308
return -1
305309

306-
def _device_type(index, clause_list, intermediate_rep):
310+
def _device_type(index, clause_list, intermediate_rep, loop_node):
307311
"""
308312
The 'device_type' clause is described in Section 2.4 Device-Specific
309313
Clauses.
310314
"""
311315
return -1
312316

313-
def _independent(index, clause_list, intermediate_rep):
317+
def _independent(index, clause_list, intermediate_rep, loop_node):
314318
"""
315319
The 'independent' clause tells the implementation that the iterations of
316320
this loop are data-independent with respect to each other. This allows
@@ -327,7 +331,7 @@ def _independent(index, clause_list, intermediate_rep):
327331
"""
328332
return -1
329333

330-
def _private(index, clause_list, intermediate_rep):
334+
def _private(index, clause_list, intermediate_rep, loop_node):
331335
"""
332336
The 'private' clause on a 'loop' construct specifies that a copy of each
333337
item in var-list will be created. If the body of the loop is executed
@@ -341,7 +345,7 @@ def _private(index, clause_list, intermediate_rep):
341345
"""
342346
return -1
343347

344-
def _reduction(index, clause_list, intermediate_rep):
348+
def _reduction(index, clause_list, intermediate_rep, loop_node):
345349
"""
346350
The 'reduction' clause specifies a reduction operator and one or more
347351
scalar variables. For each reduction variable, a private copy is created

acc/ir/intrep.py

+65-6
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@
66
about which particular backend it is using (the backend is passed into the
77
frontend as an argument).
88
"""
9+
import re
910

1011
class IrNode:
1112
"""
1213
IntermediateRepresentation tree node. Base class for all types
1314
of Nodes.
1415
"""
15-
def __init__(self, children=None):
16+
def __init__(self, lineno, children=None):
1617
if children is None:
1718
self.children = []
1819
else:
1920
self.children = children
2021

22+
self.lineno = lineno
23+
2124
def add_child(self, child):
2225
"""
2326
Adds `child` to this IrNode's list of children.
@@ -28,8 +31,8 @@ class AccNode(IrNode):
2831
"""
2932
The root of an IntermediateRepresentation tree.
3033
"""
31-
def __init__(self):
32-
super().__init__()
34+
def __init__(self, lineno):
35+
super().__init__(lineno)
3336

3437
class IntermediateRepresentation:
3538
"""
@@ -47,6 +50,62 @@ class IntermediateRepresentation:
4750
def __init__(self, meta_data):
4851
"""
4952
"""
50-
self.meta_data = meta_data
51-
self.src = meta_data.src
52-
self.root = AccNode()
53+
self.meta_data = meta_data # All the meta data
54+
self.src = meta_data.src # Shortcut to the source code
55+
self.root = AccNode(0) # The root of the tree
56+
self._lineno_lookup = {} # A hash table for line number -> IrNode
57+
58+
def add_child(self, child):
59+
"""
60+
Adds a child node to the tree. Determines where to add the node
61+
by seeing what the node's line number is and then going back
62+
up in the source code until it finds either the pragma construct
63+
that encompasses this new node or else it reaches the top of the
64+
function, in which case this node is added as a child of root.
65+
"""
66+
# Add the child to the hash table. There shouldn't already be a node for this line number.
67+
assert child.lineno not in self._lineno_lookup, "Line number {} already in hash. Hash: {}".format(child.lineno, self._lineno_lookup)
68+
self._lineno_lookup[child.lineno] = child
69+
70+
# Get the source code above child and reverse it for easy iteration
71+
src_lines_above_child = self.src.splitlines()[0:child.lineno]
72+
src_lines_above_child.reverse()
73+
line_numbers = [i for i in range(0, child.lineno)]
74+
line_numbers.reverse()
75+
76+
# Walk up from here, looking for pragmas
77+
regexp = re.compile(r"^((\s)*#(\s)*(pragma)(\s)*(acc))")
78+
for lineno, line in zip(line_numbers, src_lines_above_child):
79+
if regexp.match(line):
80+
# This line is a '#pragma acc' line. Check if it encompasses me.
81+
# If it does, we look it up and add the new child to it.
82+
# If it does not, then we keep going.
83+
if self._pragma_encompasses_child(child, lineno):
84+
parent = self._get_node_by_lineno(lineno)
85+
assert parent is not None, "Could not find node for pragma {} at line {}".format(line, lineno)
86+
parent.add_child(child)
87+
return
88+
89+
# If we get here without adding the new child to the tree, it is because
90+
# we have scanned the whole function and not found any nodes that
91+
# encompass the new child. Add the child to the top-level.
92+
self.root.add_child(child)
93+
94+
def _get_node_by_lineno(self, lineno: int) -> IrNode:
95+
"""
96+
Returns a reference to the node that was created for the pragma at the
97+
given lineno. If we can't find the node, we return None.
98+
"""
99+
if lineno in self._lineno_lookup:
100+
return self._lineno_lookup[lineno]
101+
else:
102+
return None
103+
104+
def _pragma_encompasses_child(self, child: IrNode, lineno: int) -> bool:
105+
"""
106+
Returns True if the given child node should be contained in the pragma
107+
found in the source code at lineno. Returns False otherwise.
108+
"""
109+
# TODO: Only certain types of constructs can actually encompass others
110+
# Do this method
111+
return False

0 commit comments

Comments
 (0)