Skip to content

Commit de80f68

Browse files
altanhjunrushao
authored andcommitted
[Parser][Printer] explicitly parse and print attrs_type_key in calls (apache#19)
* relax call_packed arity, return IRModule factory, print IRModule PrimFuncs * explicitly parse and print attrs_type_key on calls * print type even when attrs has no fields
1 parent 9860b33 commit de80f68

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

python/tvm/relax/parser.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -871,16 +871,25 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]:
871871
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
872872

873873
# parse call attributes if applicable
874-
if isinstance(op, rx.ExternFunc) or (isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""):
875-
attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
876-
kwargs = {}
877-
for key, val in expr.keyword_params.items():
878-
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
879-
# TODO(@altanh): might need separate attribute parsing eventually
880-
kwargs[key.value] = self.transform_expr(val)
881-
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
874+
kwargs = {}
875+
for key, val in expr.keyword_params.items():
876+
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
877+
# TODO(@altanh): might need separate attribute parsing eventually
878+
kwargs[key.value] = self.transform_expr(val)
879+
880+
is_default = False
881+
if "attrs_type_key" in kwargs:
882+
attrs_type_key = kwargs["attrs_type_key"]
883+
kwargs.pop("attrs_type_key")
884+
elif isinstance(op, tvm.ir.Op) and op.attrs_type_key != "":
885+
attrs_type_key = op.attrs_type_key
882886
else:
883-
attrs = None
887+
attrs_type_key = "DictAttrs"
888+
is_default = True
889+
890+
attrs = None
891+
if kwargs or not is_default:
892+
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
884893

885894
return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))
886895

src/relay/printer/relax_script_printer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) {
194194
doc << "(" << Doc::Concat(args, Doc::Text(", "));
195195

196196
std::vector<Doc> attrs = PrintAttrs(op->attrs);
197+
if (op->attrs.defined()) {
198+
attrs.push_back(Doc::Text("attrs_type_key=") << Doc::StrLiteral(op->attrs->GetTypeKey()));
199+
}
197200
if (!attrs.empty()) {
198201
doc << ", " << Doc::Concat(attrs);
199202
}

tests/python/relax/test_parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,17 +435,22 @@ def test_call_packed():
435435
def f(x: Tensor[(3, 3), "float32"]):
436436
# test that we can intro dim vars
437437
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
438+
w = relax.call_packed(
439+
"contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs"
440+
)
438441
return z
439442

440443
x = f.params[0]
441-
(z_bind,) = f.body.blocks[0].bindings
444+
(z_bind, w_bind) = f.body.blocks[0].bindings
442445
check_tensor_var(z_bind.var, ("n", "m"), "float32")
443446

444447
assert isinstance(z_bind.value.op, rx.ExternFunc)
445448
assert z_bind.value.op.global_symbol == "contrib.my_matmul"
446449
assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False
447450
assert structural_equal(z_bind.value.args, [x, x])
448451

452+
assert isinstance(w_bind.value.attrs, relay.op.op_attrs.ShapeOfAttrs)
453+
449454

450455
def test_primexpr_arithmetic():
451456
@rx.script

tests/python/relax/test_printer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def test_call_packed():
145145
def foo(x: Tensor[(3, 3), "float32"]):
146146
# test that we can intro dim vars
147147
z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False)
148+
w = relax.call_packed(
149+
"contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs"
150+
)
148151
return z
149152

150153
check_roundtrip(foo)

0 commit comments

Comments
 (0)