
Commit b830cd1

altanh authored and YuchenJin committed
[Parser][Printer] Switch to output annotation for dataflow blocks (apache#9)
* Relax pretty printer initial prototype
* call into TVMScriptPrinter for PrimFuncs
* most round-trip tests pass
* address comments
* implement relax.output syntax for dataflow block outputs
* remove leftover comments
* fix Var constructor on ShapeExpr annotation
* fix DataflowVar as well
1 parent db66461 · commit b830cd1
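For context, the user-facing syntax change this commit makes: dataflow blocks in Relax script previously ended with a Python `return`, and now end with a `relax.output(...)` call that names the block's outputs. A before/after sketch, with the function body adapted from the tests in this commit:

```python
# Before this commit: dataflow block outputs were written as a return.
@rx.script
def foo(x: Tensor[_, _]):
    with relax.dataflow():
        y = add(x, x)
        w = subtract(y, x)
        return y, w          # old syntax
    t = divide(y, w)
    return t

# After this commit: outputs are marked with relax.output(...).
@rx.script
def foo(x: Tensor[_, _]):
    with relax.dataflow():
        y = add(x, x)
        w = subtract(y, x)
        relax.output(y, w)   # new syntax: y and w become visible outside the block
    t = divide(y, w)
    return t
```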

File tree

5 files changed: +73 −31


python/tvm/relax/expr.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -66,7 +66,7 @@ def __init__(
         type_annotation: Optional[Type] = None,
         span: Span = None,
     ) -> None:
-        if shape_annotation is not None:
+        if isinstance(shape_annotation, (list, tuple)):
             shape_annotation = make_shape(shape_annotation)
         self.__init_handle_by_constructor__(
             _ffi_api.Var, name_hint, shape_annotation, type_annotation, span
@@ -88,7 +88,7 @@ def __init__(
         type_annotation: Optional[Type] = None,
         span: Span = None,
     ) -> None:
-        if shape_annotation is not None:
+        if isinstance(shape_annotation, (list, tuple)):
             shape_annotation = make_shape(shape_annotation)
         self.__init_handle_by_constructor__(
             _ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span
```
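This is the "fix Var constructor on ShapeExpr annotation" item from the commit message: `make_shape` is now only applied when the annotation is a Python list or tuple, so an already-constructed `ShapeExpr` passes through untouched. A minimal sketch of what that enables (the `int64` dim dtype and variable names are illustrative assumptions, not taken from this diff):

```python
import tvm
from tvm import relax as rx, tir

m = tir.Var("m", "int64")  # symbolic dims; dtype is an assumption here
n = tir.Var("n", "int64")

# A list/tuple annotation is converted through make_shape, as before.
v1 = rx.Var("v1", [m, n])

# Fixed by this commit: an existing ShapeExpr annotation is passed through
# directly instead of being fed to make_shape a second time.
shape = rx.ShapeExpr([m, n])
v2 = rx.Var("v2", shape)
```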

python/tvm/relax/parser.py

Lines changed: 32 additions & 19 deletions

```diff
@@ -99,6 +99,7 @@ class SpecialOp(Enum):
     MATCH_SHAPE = "relax.match_shape"
     CALL_PACKED = "relax.call_packed"
     DATAFLOW = "relax.dataflow"
+    DATAFLOW_OUTPUT = "relax.output"


 class RelaxTransformer(Transformer):
@@ -660,32 +661,30 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
         """
         assert len(block.stmts) > 0, "should never have an empty dataflow block"
         bindings = []
-        output_vars = []

         with self.new_scope():
-            # parse the return statement first to figure out which bindings assign normal Vars
+            # parse the output statement first to figure out which bindings assign normal Vars
             output_stmt = block.stmts[-1]
-            if not isinstance(output_stmt, ast.Return):
-                self.report_error(
-                    "dataflow blocks must end with returning the output variables",
-                    output_stmt.span,
-                )
+            output_var_names = set()
+            unbound_output_vars = {}
+            output_vars = []

-            ret_val = output_stmt.value
-            if isinstance(ret_val, ast.Var):
-                ret_val = ast.Tuple(values=[ret_val], span=ret_val.span)
-
-            if not isinstance(ret_val, ast.Tuple) or any(
-                [not isinstance(f, ast.Var) for f in ret_val.values]
+            if (
+                isinstance(output_stmt, ast.UnassignedCall)
+                and self.transform_expr(output_stmt.call.func_name) == SpecialOp.DATAFLOW_OUTPUT
             ):
+                for var in output_stmt.call.params:
+                    if not isinstance(var, ast.Var):
+                        self.report_error(f"dataflow block outputs must be variables", var.span)
+                    output_var_names.add(var.id.name)
+                    unbound_output_vars[var.id.name] = var
+            else:
                 self.report_error(
-                    "the returned values must be variables",
-                    ret_val.span,
+                    f"dataflow blocks must end with a {SpecialOp.DATAFLOW_OUTPUT.value} statement",
+                    output_stmt.span,
                 )

-            # output variables are bound to normal (not data flow) Vars
-            output_var_names = {var.id.name for var in ret_val.values}
-
+            # output variables are bound to normal (not dataflow) Vars
             for binding_stmt in block.stmts[:-1]:
                 if not isinstance(binding_stmt, (ast.Assign, ast.UnassignedCall)):
                     self.report_error(
@@ -704,6 +703,18 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:
                     output_vars.append(var)
                 else:
                     output_vars.append(binding.var)
+                    unbound_output_vars.pop(binding_stmt.lhs.id.name)
+
+            # check that the output variables are all bound locally
+            for unbound_var in unbound_output_vars.values():
+                self._diagnostic_context.emit(
+                    "error",
+                    "dataflow output variables must be bound locally in the block",
+                    unbound_var.span,
+                )
+            # FIXME(@altanh): TVMDiagnosticCtx has hard-coded `emit` to always be an error and raise
+            # an exception on the first call
+            self._diagnostic_context.render()

         # make output variables visible in parent scope
         for v in output_vars:
@@ -769,8 +780,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
             )
             op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
             args = [self.transform_expr(expr.params[1])]
-        else:
+        elif isinstance(op, (tvm.ir.Op, relay.Expr)):
            args = [self.transform_expr(arg) for arg in expr.params]
+        else:
+            self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
         # TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass?
         return relay.Call(op, args, span=self.to_tvm_span(expr.span))
```
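The heart of the new validation is simple bookkeeping: collect the names listed in `relax.output(...)` up front, cross each one off as its binding is parsed, and report whatever is left. A standalone sketch of that check, detached from the parser machinery (the function and its inputs are illustrative; the real code works on AST nodes and spans, and uses `dict.pop` on `unbound_output_vars`):

```python
def check_dataflow_outputs(output_names, binding_names):
    """Return the output names that were never bound inside the block."""
    unbound = set(output_names)          # names listed in relax.output(...)
    for name in binding_names:           # names bound by statements in the block
        unbound.discard(name)
    return unbound                       # non-empty => parser emits
                                         # "dataflow output variables must be bound locally"

# Mirrors the tests below: outputs y, w are bound; x (a parameter) and q are not.
assert check_dataflow_outputs(["y", "w"], ["y", "z", "w"]) == set()
assert check_dataflow_outputs(["x", "y", "w", "q"], ["y", "z", "w"]) == {"x", "q"}
```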

src/printer/relax_script_printer.cc

Lines changed: 1 addition & 1 deletion

```diff
@@ -246,7 +246,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) {
     }
   }
   ICHECK(!return_vars.empty()) << "dataflow blocks should have at least one output variable";
-  body << "return " << Doc::Concat(return_vars, Doc::Text(", "));
+  body << "relax.output(" << Doc::Concat(return_vars, Doc::Text(", ")) << ")";
  block << "with relax.dataflow():" << Doc::NewLine(4);
  block << Doc::Indent(4, body) << Doc::NewLine();
  return block;
```
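With this one-line change the printer emits the same surface syntax the parser now accepts, which is what makes the round-trip tests pass. Roughly, a dataflow block with outputs `y` and `w` now prints as follows (bindings illustrative, whitespace approximated from the `Doc::Indent(4, ...)` call above):

```python
with relax.dataflow():
    y = add(x, x)
    w = subtract(y, x)
    relax.output(y, w)
```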

tests/python/relax/test_parser.py

Lines changed: 36 additions & 7 deletions

```diff
@@ -258,12 +258,10 @@ def test_dataflow():
     @rx.script
     def foo(x: Tensor[_, _]):
         with relax.dataflow():
-            # TODO: parse this
-            # nonlocal y, w
             y = add(x, x)
             z = multiply(y, x)
             w = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t = divide(y, w)
         return t
@@ -295,7 +293,7 @@ def foo(x: Tensor[_, _]):
             z = multiply(y, x)
             relax.match_shape((n, m), z.shape)
             w: Tensor[(n, m), _] = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t: Tensor[(n, m), _] = divide(y, w)
         return t
@@ -308,7 +306,7 @@ def foo(x: Tensor[_, _]):
             y = add(x, x)
             z = multiply(y, x)
             w = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t = divide(y, z)
         return t
@@ -321,7 +319,7 @@ def foo(x: Tensor[_, _]):
             y = add(x, x)
             z = multiply(y, x)
             w = subtract(z, x)
-            return y, w
+            relax.output(y, z)
         t = divide(y, z)
         return t
@@ -334,11 +332,42 @@ def foo(x: Tensor[_, _]):
             y = add(x, x)
             z = multiply(y, x)
             w = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t = divide(y, z)
         return t


+@pytest.mark.xfail
+def test_dataflow_unbound_outputs():
+    @rx.script
+    def foo(x: Tensor[_, _]):
+        with relax.dataflow():
+            y = add(x, x)
+            z = multiply(y, x)
+            w = subtract(z, x)
+            relax.output(x, y, w, q)
+        t = divide(y, z)
+        return t
+
+
+@pytest.mark.xfail
+def test_invalid_special_op_dataflow():
+    @rx.script
+    def foo(x: Tensor):
+        y = add(x, x)
+        z = relax.dataflow()
+        return z
+
+
+@pytest.mark.xfail
+def test_invalid_special_op_output():
+    @rx.script
+    def foo(x: Tensor):
+        y = add(x, x)
+        z = relax.output(y)
+        return z
+
+
 @pytest.mark.xfail
 def test_func_no_return_fail():
     @rx.script
```

tests/python/relax/test_printer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -83,7 +83,7 @@ def foo(x: Tensor[_, _]):
             y = add(x, x)
             z = multiply(y, x)
             w = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t = divide(y, w)
         return t
@@ -98,7 +98,7 @@ def foo(x: Tensor[_, _]):
             z = multiply(y, x)
             relax.match_shape((n, m), z.shape)
             w: Tensor[(n, m), _] = subtract(z, x)
-            return y, w
+            relax.output(y, w)
         t: Tensor[(n, m), _] = divide(y, w)
         return t
```
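These printer tests back the "most round-trip tests pass" item in the commit message: each script is parsed, printed, and re-parsed. In sketch form, the checking pattern looks like the following (`astext` and `fromtext` are hypothetical stand-ins for the real parse/print entry points, which live in the relax parser and printer modules; `tvm.ir.structural_equal` is real):

```python
def check_roundtrip(f):
    printed = rx.parser.astext(f)           # hypothetical: pretty-print to Relax script
    reparsed = rx.parser.fromtext(printed)  # hypothetical: parse the printed text again
    assert tvm.ir.structural_equal(f, reparsed)
```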
