12
12
from .util import _internal_assert
13
13
from . import calls
14
14
from . import util
15
- from .var_decl import determine_variable_usage
15
+ from .preprocessor import determine_variable_usage
16
16
from ..api import all as _all
17
17
from ..api import any as _any
18
18
from ..container import Array
@@ -61,6 +61,7 @@ class Symbol(Enum):
61
61
BufferVar = 7
62
62
LoopVar = 8
63
63
ConstLoopVar = 9
64
+ ThreadBind = 10
64
65
65
66
66
67
class HybridParser (ast .NodeVisitor ):
@@ -117,7 +118,10 @@ def __init__(self, args, usage, symbols, func_name=None):
117
118
self .symbols = {} # Symbol table
118
119
for k , v in symbols .items ():
119
120
if isinstance (v , types .FunctionType ):
120
- self .symbols [k ] = Symbol .Callable , v
121
+ self .add_symbol (k , Symbol .Callable , v )
122
+
123
+ self .binds = {} # Thread binds
124
+ self .device = 0 # Is it generating device
121
125
122
126
self .func_name = func_name # The name of the function to be lowered
123
127
self .outputs = [] # Output tensors' name
@@ -126,6 +130,25 @@ def __init__(self, args, usage, symbols, func_name=None):
126
130
self .returned = False # If this function has a valid return
127
131
128
132
133
def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
    """Register `key` in the symbol table as the pair `(ty, val)`.

    If `key` is already present, an internal assertion is raised whose
    message shows both the existing entry and the incoming one.

    For `Symbol.ThreadBind` entries the value is additionally tracked in
    `self.binds`, keyed by `val.var.name`:

    * the first value seen under a name is recorded as-is;
    * a later value with the same name must have an equal `dom.extent`
      (checked with `_ir_pass.Equal`), and the symbol table entry is then
      rewritten to point at the value recorded first.
    """
    if key in self.symbols:
        previous = str(self.symbols[key])
        incoming = str((ty, val))
        _internal_assert(False,
                         "Name conflict in symbol table! [%s] %s -> %s" % (key, previous, incoming))

    self.symbols[key] = ty, val

    if ty == Symbol.ThreadBind:
        thread_name = val.var.name
        if thread_name in self.binds:
            # Reuse the value recorded first so every symbol bound to this
            # thread name refers to the same object.
            recorded = self.binds[thread_name]
            _internal_assert(_ir_pass.Equal(recorded.dom.extent, val.dom.extent),
                             "Thread extents should be uniform!")
            self.symbols[key] = ty, recorded
        else:
            self.binds[thread_name] = val
151
+
129
152
130
153
def wrap_up_realize (self , node , body ):
131
154
"""Wrap up all the variables which will no longer be used"""
@@ -141,11 +164,14 @@ def wrap_up_realize(self, node, body):
141
164
continue
142
165
elif 'Buffer' in ty .name :
143
166
_buf = entry
144
- _scope = ty . name [: - 6 ]. lower () if ty is not Symbol .BufferVar else 'global'
167
+ _scope = 'global' if ty is Symbol .BufferVar else ty . name [: - 6 ]. lower ()
145
168
to_pop .append (key )
146
169
else :
147
170
continue
148
171
172
+ if _scope == 'global' :
173
+ body = self .wrap_up_binds (body )
174
+
149
175
_domain = [_make .range_by_min_extent (0 , i ) for i in _buf .shape ]
150
176
_dtype = _buf .dtype
151
177
_true = _api .convert (True )
@@ -158,6 +184,14 @@ def wrap_up_realize(self, node, body):
158
184
return body
159
185
160
186
187
def wrap_up_binds(self, body):
    """Wrap `body` in one 'thread_extent' attribute statement per recorded bind.

    Every value stored in `self.binds` contributes one `_make.AttrStmt`
    whose attribute value is that entry's `dom.extent`; each new statement
    encloses the result built so far. Afterwards `self.binds` is reset to a
    fresh empty dict.

    Returns the wrapped statement, or `body` unchanged when no binds are
    recorded.
    """
    wrapped = body
    for bound in self.binds.values():
        wrapped = _make.AttrStmt(bound, 'thread_extent', bound.dom.extent, wrapped)
    # Binds are consumed once emitted; start over with an empty table.
    self.binds = {}
    return wrapped
193
+
194
+
161
195
#pylint: disable=invalid-name, missing-docstring
162
196
def visit_Module (self , node ):
163
197
_internal_assert (len (node .body ) == 1 , \
@@ -173,10 +207,10 @@ def visit_FunctionDef(self, node):
173
207
self .func_name = node .name
174
208
for idx , arg in enumerate (node .args .args ):
175
209
_attr = 'id' if sys .version_info [0 ] < 3 else 'arg' # To make py2 and 3 compatible
176
- self .symbols [ getattr (arg , _attr )] = ( Symbol .Input , self .args [idx ])
210
+ self .add_symbol ( getattr (arg , _attr ), Symbol .Input , self .args [idx ])
177
211
res = visit_list_to_block (self .visit , node .body )
178
212
res = self .wrap_up_realize (node , res )
179
- return res
213
+ return self . wrap_up_binds ( res )
180
214
181
215
182
216
def visit_Expr (self , node ):
@@ -189,6 +223,8 @@ def visit_Name(self, node):
189
223
_internal_assert (name in self .symbols , "Unknown symbol %s!" % name )
190
224
if ty in [Symbol .LoopVar , Symbol .Input , Symbol .ConstLoopVar ]:
191
225
return entry
226
+ if ty is Symbol .ThreadBind :
227
+ return entry .var
192
228
if ty is Symbol .ConstVar :
193
229
return entry if isinstance (node .ctx , ast .Load ) else None
194
230
if ty is Symbol .BufferVar :
@@ -237,7 +273,7 @@ def visit_Assign(self, node):
237
273
for i in range (rhs .num_outputs ):
238
274
_internal_assert (isinstance (node .targets [i ], ast .Name ),
239
275
"You should bind a pure name to the tensors" )
240
- self .symbols [ node .targets [i ].id ] = Symbol .GlobalBuffer , rhs .output (i )
276
+ self .add_symbol ( node .targets [i ].id , Symbol .GlobalBuffer , rhs .output (i ) )
241
277
rmap [rhs .outputs [i ].op ] = rhs .output (i )
242
278
return util .replace_io (rhs .body , rmap )
243
279
@@ -260,15 +296,19 @@ def visit_Assign(self, node):
260
296
if isinstance (rhs , tuple ):
261
297
shape , dtype , scope = rhs
262
298
ph = _api .placeholder (shape , dtype = dtype , name = lhs )
263
- self .symbols [ lhs ] = getattr (Symbol , scope .title () + "Buffer" ), ph
299
+ self .add_symbol ( lhs , getattr (Symbol , scope .title () + "Buffer" ), ph )
264
300
if scope == 'output' :
265
301
self .outputs .append (lhs )
266
302
return util .make_nop ()
267
303
if isinstance (rhs , util .halide_imm_types ) and ast .Store not in rw :
268
- self .symbols [ lhs ] = Symbol .ConstVar , rhs
304
+ self .add_symbol ( lhs , Symbol .ConstVar , rhs )
269
305
else :
306
+ _internal_assert (self .device == 0 ,
307
+ "Single variable not supported in devices' side!\n " + \
308
+ "If you are using GPU, please allocate a 'local' spad " + \
309
+ "outside the bind body" )
270
310
ph = _api .placeholder ((1 , ), dtype = rhs .dtype , name = lhs )
271
- self .symbols [ lhs ] = Symbol .BufferVar , ph
311
+ self .add_symbol ( lhs , Symbol .BufferVar , ph )
272
312
lhs = self .visit (lhs_ )
273
313
if lhs is not None :
274
314
buf , args = lhs
@@ -356,7 +396,7 @@ def visit_If(self, node):
356
396
if node .orelse :
357
397
else_body = visit_list_to_block (self .visit , node .orelse )
358
398
else :
359
- else_body = util . make_nop ()
399
+ else_body = None
360
400
return _make .IfThenElse (cond , if_body , else_body )
361
401
362
402
@@ -445,28 +485,31 @@ def visit_For(self, node):
445
485
446
486
bodies = []
447
487
for i in range (low , low + ext ):
448
- self .symbols [ _name ] = Symbol .ConstLoopVar , i
488
+ self .add_symbol ( _name , Symbol .ConstLoopVar , i )
449
489
body = visit_list_to_block (self .visit , node .body )
450
490
body = self .wrap_up_realize (node , body )
451
491
bodies .append (body )
492
+ self .symbols .pop (_name )
452
493
return concat_list_to_block (bodies )
453
494
454
495
if iter_var is None :
455
- _internal_assert (for_type is not None , "The loop bind function parse error!" )
496
+ _internal_assert (for_type is not None , "The loop iterating function parse error!" )
456
497
offset = iter_var = _api .var (_name )
457
498
if not _ir_pass .Equal (low , _api .const (0 , 'int32' )):
458
499
offset = iter_var + low
459
- self .symbols [ _name ] = Symbol .LoopVar , offset
500
+ self .add_symbol ( _name , Symbol .LoopVar , offset )
460
501
_body = visit_list_to_block (self .visit , node .body )
461
502
else :
462
- _internal_assert (for_type is None , "The loop iterating function parse error!" )
463
- self .symbols [_name ] = Symbol .LoopVar , iter_var .var
503
+ _internal_assert (for_type is None , "The loop bind function parse error!" )
504
+ self .add_symbol (_name , Symbol .ThreadBind , iter_var )
505
+ self .device += 1
464
506
_body = visit_list_to_block (self .visit , node .body )
507
+ self .device -= 1
465
508
466
509
_body = self .wrap_up_realize (node , _body )
467
510
468
511
if for_type is None :
469
- res = _make . AttrStmt ( iter_var , 'thread_extent' , ext , _body )
512
+ res = _body
470
513
else :
471
514
_internal_assert (not isinstance (for_type , tuple ), \
472
515
"Micro expansion should be handled before!" )
0 commit comments