
Commit 96f324b

were authored and wweic committed
fix lint (apache#2649)
1 parent 7814755 · commit 96f324b

File tree (4 files changed: +83 additions, -32 deletions)

    python/tvm/hybrid/calls.py
    python/tvm/hybrid/parser.py
    python/tvm/hybrid/var_decl.py → python/tvm/hybrid/preprocessor.py
    tests/python/unittest/test_hybrid_script.py


python/tvm/hybrid/calls.py

Lines changed: 1 addition & 1 deletion
@@ -45,8 +45,8 @@ def bind(func_id, args):
     _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
     _internal_assert(isinstance(args[0], str), \
                      "A loop bind's first argument should be a string!")
-    iter_var = _api.thread_axis(args[0])
     low, ext = _api.const(0, "int32"), args[1]
+    iter_var = _api.thread_axis((low, ext), args[0])
     for_type = None
     return iter_var, low, ext, for_type

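Note on the change above: the thread axis is now constructed with an explicit iteration domain, so the resulting IterVar carries its extent from creation instead of having the extent supplied separately during lowering. A minimal sketch of the difference at the public TVM API level (the extent 64 is illustrative):

    import tvm

    # Before: no domain attached; the extent had to be provided later,
    # when the thread_extent attribute was emitted.
    tx_old = tvm.thread_axis('threadIdx.x')

    # After: the domain (0, 64) is baked into the IterVar, so later code
    # (e.g. the uniform-extent check in parser.add_symbol) can read
    # tx_new.dom.extent directly.
    tx_new = tvm.thread_axis((0, 64), 'threadIdx.x')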

python/tvm/hybrid/parser.py

Lines changed: 59 additions & 16 deletions
@@ -12,7 +12,7 @@
 from .util import _internal_assert
 from . import calls
 from . import util
-from .var_decl import determine_variable_usage
+from .preprocessor import determine_variable_usage
 from ..api import all as _all
 from ..api import any as _any
 from ..container import Array
@@ -61,6 +61,7 @@ class Symbol(Enum):
     BufferVar = 7
     LoopVar = 8
     ConstLoopVar = 9
+    ThreadBind = 10


 class HybridParser(ast.NodeVisitor):
@@ -117,7 +118,10 @@ def __init__(self, args, usage, symbols, func_name=None):
         self.symbols = {} # Symbol table
         for k, v in symbols.items():
             if isinstance(v, types.FunctionType):
-                self.symbols[k] = Symbol.Callable, v
+                self.add_symbol(k, Symbol.Callable, v)
+
+        self.binds = {} # Thread binds
+        self.device = 0 # Non-zero while generating device-side code

         self.func_name = func_name # The name of the function to be lowered
         self.outputs = [] # Output tensors' name
@@ -126,6 +130,25 @@ def __init__(self, args, usage, symbols, func_name=None):
         self.returned = False # If this function has a valid return


+    def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
+        """Add value to the symbol table context"""
+        if key in self.symbols.keys():
+            old = str(self.symbols[key])
+            new = str((ty, val))
+            _internal_assert(False,
+                             "Name conflict in symbol table! [%s] %s -> %s" % (key, old, new))
+
+        self.symbols[key] = ty, val
+
+        if ty == Symbol.ThreadBind:
+            if val.var.name not in self.binds.keys():
+                self.binds[val.var.name] = val
+                return
+            val_ = self.binds[val.var.name]
+            _internal_assert(_ir_pass.Equal(val_.dom.extent, val.dom.extent),
+                             "Thread extents should be uniform!")
+            self.symbols[key] = ty, val_
+

     def wrap_up_realize(self, node, body):
         """Wrap up all the variables which will no longer be used"""
@@ -141,11 +164,14 @@ def wrap_up_realize(self, node, body):
                 continue
             elif 'Buffer' in ty.name:
                 _buf = entry
-                _scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
+                _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
                 to_pop.append(key)
             else:
                 continue

+            if _scope == 'global':
+                body = self.wrap_up_binds(body)
+
             _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = _api.convert(True)
@@ -158,6 +184,14 @@ def wrap_up_realize(self, node, body):
         return body


+    def wrap_up_binds(self, body):
+        for _, iter_var in self.binds.items():
+            ext = iter_var.dom.extent
+            body = _make.AttrStmt(iter_var, 'thread_extent', ext, body)
+        self.binds = {}
+        return body
+
+
     #pylint: disable=invalid-name, missing-docstring
     def visit_Module(self, node):
         _internal_assert(len(node.body) == 1, \
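`wrap_up_binds` wraps the accumulated body in one `thread_extent` AttrStmt per distinct thread tag and then clears `self.binds`, so each bind is emitted exactly once, at the enclosing global realize or at the function root. A hedged sketch of the statement it constructs (names are illustrative):

    import tvm
    from tvm import make as _make

    tx = tvm.thread_axis((0, 8), 'threadIdx.x')
    body = _make.Evaluate(tvm.const(0, 'int32'))  # stand-in for the real body
    # One wrapper per recorded bind, mirroring the loop in wrap_up_binds:
    body = _make.AttrStmt(tx, 'thread_extent', tx.dom.extent, body)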
@@ -173,10 +207,10 @@ def visit_FunctionDef(self, node):
         self.func_name = node.name
         for idx, arg in enumerate(node.args.args):
             _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
-            self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx])
+            self.add_symbol(getattr(arg, _attr), Symbol.Input, self.args[idx])
         res = visit_list_to_block(self.visit, node.body)
         res = self.wrap_up_realize(node, res)
-        return res
+        return self.wrap_up_binds(res)


     def visit_Expr(self, node):
@@ -189,6 +223,8 @@ def visit_Name(self, node):
         _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
         if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
             return entry
+        if ty is Symbol.ThreadBind:
+            return entry.var
         if ty is Symbol.ConstVar:
             return entry if isinstance(node.ctx, ast.Load) else None
         if ty is Symbol.BufferVar:
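The new `ThreadBind` branch returns `entry.var` rather than the entry itself: a thread axis is an IterVar bundling a domain, a thread tag, and a plain Var, and loop-body expressions such as `a[i, k]` must be built from the Var. A small hedged sketch (attribute access as in 0.x-era TVM):

    import tvm

    iv = tvm.thread_axis((0, 8), 'threadIdx.x')
    assert iv.var.name == 'threadIdx.x'   # the Var that visit_Name returns
    assert iv.dom.extent.value == 8       # the extent that add_symbol compares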
@@ -237,7 +273,7 @@ def visit_Assign(self, node):
             for i in range(rhs.num_outputs):
                 _internal_assert(isinstance(node.targets[i], ast.Name),
                                  "You should bind a pure name to the tensors")
-                self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i)
+                self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                 rmap[rhs.outputs[i].op] = rhs.output(i)
             return util.replace_io(rhs.body, rmap)

@@ -260,15 +296,19 @@ def visit_Assign(self, node):
         if isinstance(rhs, tuple):
             shape, dtype, scope = rhs
             ph = _api.placeholder(shape, dtype=dtype, name=lhs)
-            self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph
+            self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
             if scope == 'output':
                 self.outputs.append(lhs)
             return util.make_nop()
         if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
-            self.symbols[lhs] = Symbol.ConstVar, rhs
+            self.add_symbol(lhs, Symbol.ConstVar, rhs)
         else:
+            _internal_assert(self.device == 0,
+                             "Single variable not supported on the device side!\n" + \
+                             "If you are using GPU, please allocate a 'local' scratchpad " + \
+                             "outside the bind body")
             ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
-            self.symbols[lhs] = Symbol.BufferVar, ph
+            self.add_symbol(lhs, Symbol.BufferVar, ph)
         lhs = self.visit(lhs_)
         if lhs is not None:
             buf, args = lhs
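The `self.device` guard above rejects plain scalar assignments inside a bind body, since they would otherwise be lowered into a single-element global buffer that cannot live on the device side. A hedged sketch of the rejected pattern and the sanctioned workaround (shapes and names are illustrative; compare the rewritten `test_bind` below):

    import tvm

    @tvm.hybrid.script
    def row_copy(a):
        c = output_tensor((a.shape[0],), a.dtype)
        tmp = allocate((1,), a.dtype, 'local')  # OK: explicit local scratchpad
        for i in bind('threadIdx.x', a.shape[0]):
            tmp[0] = a[i]        # writing the scratchpad is fine
            c[i] = tmp[0]
            # s = 0.             # would trip the assert: scalar on device side
        return c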
@@ -356,7 +396,7 @@ def visit_If(self, node):
         if node.orelse:
             else_body = visit_list_to_block(self.visit, node.orelse)
         else:
-            else_body = util.make_nop()
+            else_body = None
         return _make.IfThenElse(cond, if_body, else_body)

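Passing `None` here yields an `IfThenElse` with no else case at all, where the old code attached a dead no-op branch. A minimal hedged sketch with the 0.x-era IR constructors:

    import tvm

    cond = tvm.const(1, 'uint1')
    then_body = tvm.make.Evaluate(tvm.const(0, 'int32'))
    # Before: IfThenElse(cond, then_body, nop) carried an empty else branch.
    # After:  the else slot is simply absent.
    stmt = tvm.make.IfThenElse(cond, then_body, None)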

@@ -445,28 +485,31 @@ def visit_For(self, node):

             bodies = []
             for i in range(low, low + ext):
-                self.symbols[_name] = Symbol.ConstLoopVar, i
+                self.add_symbol(_name, Symbol.ConstLoopVar, i)
                 body = visit_list_to_block(self.visit, node.body)
                 body = self.wrap_up_realize(node, body)
                 bodies.append(body)
+                self.symbols.pop(_name)
             return concat_list_to_block(bodies)

         if iter_var is None:
-            _internal_assert(for_type is not None, "The loop bind function parse error!")
+            _internal_assert(for_type is not None, "The loop iterating function parse error!")
             offset = iter_var = _api.var(_name)
             if not _ir_pass.Equal(low, _api.const(0, 'int32')):
                 offset = iter_var + low
-            self.symbols[_name] = Symbol.LoopVar, offset
+            self.add_symbol(_name, Symbol.LoopVar, offset)
             _body = visit_list_to_block(self.visit, node.body)
         else:
-            _internal_assert(for_type is None, "The loop iterating function parse error!")
-            self.symbols[_name] = Symbol.LoopVar, iter_var.var
+            _internal_assert(for_type is None, "The loop bind function parse error!")
+            self.add_symbol(_name, Symbol.ThreadBind, iter_var)
+            self.device += 1
             _body = visit_list_to_block(self.visit, node.body)
+            self.device -= 1

         _body = self.wrap_up_realize(node, _body)

         if for_type is None:
-            res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
+            res = _body
         else:
             _internal_assert(not isinstance(for_type, tuple), \
                              "Micro expansion should be handled before!")
python/tvm/hybrid/var_decl.py → python/tvm/hybrid/preprocessor.py

File renamed without changes.

tests/python/unittest/test_hybrid_script.py

Lines changed: 23 additions & 15 deletions
@@ -300,6 +300,7 @@ def test_bind():
     if not tvm.gpu(0).exist:
         print('[Warning] No GPU found! Skip bind test!')
         return
+
     @script
     def vec_add(a, b):
         c = output_tensor((1000, ), 'float32')
@@ -326,23 +327,29 @@ def raw(a, b):
     func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
     run_and_check(func, ins, outs=outs, target='cuda')

-    # Test loop binds
+
     @tvm.hybrid.script
-    def goo(a, b):
-        c = output_tensor(a.shape, a.dtype)
-        len_b = len(b)
-        for i in const_range(len_b * 2):
-            if i < len_b:
-                c[i] = a[i] + b[i]
-            else:
-                c[i - len_b] = a[i - len_b] + b[i - len_b]
+    def foo(a):
+        c = output_tensor((a.shape[0],), a.dtype)
+        total = allocate((1,), a.dtype, 'local')
+        len_i = a.shape[0]
+        len_j = a.shape[1]
+        for i in bind('threadIdx.x', len_i):
+            total[0] = 0.
+            for k in const_range(len_j):
+                total[0] += a[i, k]
+            c[i] = total[0]
+
         return c
-    a = tvm.placeholder((5, ), name='a', dtype='int32')
-    b = [1, 2, 3, 4, 5]
-    c = goo(a, tvm.convert(b))
-    sch = tvm.create_schedule(c.op)
-    func, ins, outs = run_and_check(goo, [a, b], sch=sch, outs=[c])
-    run_and_check(func, ins, outs=outs)
+
+    a = tvm.placeholder((8, 4), 'float32')
+    c = foo(a)
+    s = tvm.create_schedule(c.op)
+    ir = tvm.lower(s, [a, c], simple_mode=True)
+    assert not isinstance(ir, tvm.stmt.AttrStmt)
+    func, ins, outs = run_and_check(foo, [a], target='cuda')
+    run_and_check(func, ins, outs=outs, target='cuda')


 def test_math_intrin():
     @script
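The rewritten test sums each row of an (8, 4) input with one CUDA thread per row: `threadIdx.x` ranges over the rows while the inner `const_range` loop accumulates the columns into a per-thread 'local' scratchpad. A hedged NumPy reference for what `foo` computes (intuition only; `run_and_check` performs the actual verification):

    import numpy as np

    a_np = np.random.rand(8, 4).astype('float32')
    c_ref = a_np.sum(axis=1)   # foo: c[i] = sum over k of a[i, k]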
@@ -455,6 +462,7 @@ def share_vec_add(a, b):

         a = tvm.placeholder((256, ), dtype='float32', name='a')
         b = tvm.placeholder((256, ), dtype='float32', name='b')
+        c = share_vec_add(a, b)
         func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
         run_and_check(func, ins, outs=outs, target='cuda')
     else:
