2222 */
2323
2424#include < tvm/relax/ir_builder.h>
25+ #include < tvm/relax/op_attr_types.h>
2526#include < tvm/relay/op.h>
2627
2728namespace tvm {
@@ -38,59 +39,84 @@ IRBuilder IRBuilderNode::Create() {
3839
3940void IRBuilderNode::FillFuncNameParam (const Array<Var>& params, const std::string& func_name) {
4041 if (!func_name.empty ()) {
41- this ->func .func_name = GlobalVar (func_name);
42+ this ->func_ .func_name = GlobalVar (func_name);
4243 }
43-
44- this ->func .params = params;
44+
45+ this ->func_ .params = params;
4546}
4647
4748void IRBuilderNode::BuildFunction () {
48- SeqExpr seq = SeqExpr (this ->func .binding_blocks , this ->func .ret );
49- this ->func .func = Function (this ->func .func_name , this ->func .params , seq, {});
50- this ->global_var_counter = 0 ;
49+ SeqExpr seq = SeqExpr (this ->func_ .binding_blocks , this ->func_ .ret );
50+ this ->func_ .func = Function (this ->func_ .func_name , this ->func_ .params , seq, {});
51+ this ->global_var_counter_ = 0 ;
5152}
5253
5354void IRBuilderNode::BuildBlock () {
54- if (!this ->func .bindings .empty ()) {
55- if (is_dataflow ) {
56- this ->func .binding_blocks .emplace_back (DataflowBlock (this ->func .bindings ));
55+ if (!this ->func_ .bindings .empty ()) {
56+ if (is_dataflow_ ) {
57+ this ->func_ .binding_blocks .emplace_back (DataflowBlock (this ->func_ .bindings ));
5758 } else {
58- this ->func .binding_blocks .emplace_back (BindingBlock (this ->func .bindings ));
59+ this ->func_ .binding_blocks .emplace_back (BindingBlock (this ->func_ .bindings ));
5960 }
60- this ->func .bindings .clear ();
61+ this ->func_ .bindings .clear ();
6162 }
62- this ->dataflow_var_counter = 0 ;
63- this ->is_dataflow = !this ->is_dataflow ;
63+ this ->dataflow_var_counter_ = 0 ;
64+ this ->is_dataflow_ = !this ->is_dataflow_ ;
65+ }
66+
67+ Optional<RelayExpr> InferShape (const Call& call, DiagnosticContext diag_ctx) {
68+ auto op_map = Op::GetAttrMap<FInferShape>(" FInferShape" );
69+ Op op = Downcast<Op>(call->op );
70+ return op_map[op](call, diag_ctx);
71+ }
72+
73+ Type InferType (const Call& call, DiagnosticContext diag_ctx) {
74+ auto op_map = Op::GetAttrMap<FInferType>(" FInferType" );
75+ Op op = Downcast<Op>(call->op );
76+ return op_map[op](call, diag_ctx);
6477}
6578
6679Var IRBuilderNode::Emit (const Call& call) {
6780 Var var;
68- if (is_dataflow ) {
69- var = DataflowVar (Id (" lv" + std::to_string (dataflow_var_counter ++)), NullOpt, NullOpt);
81+ if (is_dataflow_ ) {
82+ var = DataflowVar (Id (" lv" + std::to_string (dataflow_var_counter_ ++)), NullOpt, NullOpt);
7083 } else {
71- var = Var (Id (" gv" + std::to_string (global_var_counter++)), NullOpt, NullOpt);
84+ var = Var (Id (" gv" + std::to_string (global_var_counter_++)), NullOpt, NullOpt);
85+ }
86+
87+ // Shape inference
88+ auto inferred_shape = InferShape (call, this ->diag_ctx_ );
89+ if (inferred_shape.defined ()) {
90+ if (auto * shape_expr = inferred_shape.value ().as <ShapeExprNode>()) {
91+ call->shape_ = GetRef<Expr>(shape_expr);
92+ var->shape_ = call->shape_ ;
93+ }
7294 }
95+ // Type inference
96+ auto inferred_type = InferType (call, this ->diag_ctx_ );
97+ call->checked_type_ = inferred_type;
98+ var->checked_type_ = inferred_type;
7399
74- this ->func .bindings .emplace_back (VarBinding (var, call));
100+ this ->func_ .bindings .emplace_back (VarBinding (var, call));
75101 return var;
76102}
77103
78104Var IRBuilderNode::EmitOutput (const Expr& output) {
79105 Var ret;
80- if (is_dataflow ) {
81- ret = Var (Id (" gv" + std::to_string (global_var_counter ++)), NullOpt, NullOpt);
106+ if (is_dataflow_ ) {
107+ ret = Var (Id (" gv" + std::to_string (global_var_counter_ ++)), NullOpt, NullOpt);
82108 ret->shape_ = output->shape_ ;
83109 ret->checked_type_ = output->checked_type_ ;
84- this ->func .bindings .emplace_back (VarBinding (ret, output));
110+ this ->func_ .bindings .emplace_back (VarBinding (ret, output));
85111 } else {
86- this ->func .ret = output;
112+ this ->func_ .ret = output;
87113 }
88114 return ret;
89115}
90116
91- Function IRBuilderNode::Get () { return this ->func .func ; }
117+ Function IRBuilderNode::Get () { return this ->func_ .func ; }
92118
93- std::vector<BindingBlock> IRBuilderNode::GetBlocks () { return this ->func .binding_blocks ; }
119+ std::vector<BindingBlock> IRBuilderNode::GetBlocks () { return this ->func_ .binding_blocks ; }
94120
95121class FunctionScope ::Internal {
96122 public:
@@ -121,20 +147,16 @@ DataflowScope::DataflowScope(IRBuilder ib) {
121147 data_ = std::move (n);
122148}
123149
124- void DataflowScope::EnterWithScope () {
125- this ->get ()->ir_builder ->BuildBlock ();
126- }
150+ void DataflowScope::EnterWithScope () { this ->get ()->ir_builder ->BuildBlock (); }
127151
128- void DataflowScope::ExitWithScope () {
129- this ->get ()->ir_builder ->BuildBlock ();
130- }
152+ void DataflowScope::ExitWithScope () { this ->get ()->ir_builder ->BuildBlock (); }
131153
132154TVM_REGISTER_GLOBAL (" relax.IRBuilderCreate" ).set_body_typed(IRBuilderNode::Create);
133155
134156TVM_REGISTER_GLOBAL (" relax.IRBuilderFillFuncNameParam" )
135- .set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
136- return builder->FillFuncNameParam (params, func_name);
137- });
157+ .set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
158+ return builder->FillFuncNameParam (params, func_name);
159+ });
138160
139161TVM_REGISTER_GLOBAL (" relax.IRBuilderBuildFunction" ).set_body_typed([](IRBuilder builder) {
140162 return builder->BuildFunction ();
@@ -145,9 +167,9 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder,
145167});
146168
147169TVM_REGISTER_GLOBAL (" relax.IRBuilderEmitOutput" )
148- .set_body_typed([](IRBuilder builder, const Expr& output) {
149- return builder->EmitOutput (output);
150- });
170+ .set_body_typed([](IRBuilder builder, const Expr& output) {
171+ return builder->EmitOutput (output);
172+ });
151173
152174TVM_REGISTER_GLOBAL (" relax.IRBuilderGet" ).set_body_typed([](IRBuilder builder) {
153175 return builder->Get ();
0 commit comments