@@ -34,27 +34,30 @@ namespace relay {
34
34
35
35
class DynamicToStaticMutator : public MixedModeMutator {
36
36
public:
37
- DynamicToStaticMutator () {
37
+ DynamicToStaticMutator (IRModule mod, Function func) : mod_(mod), func_(func ) {
38
38
op_map_ = {
39
39
{Op::Get (" dyn.reshape" ),
40
- [](const CallNode* call_node) {
41
- if (const ConstantNode* shape = call_node->args [1 ].as <ConstantNode>()) {
40
+ [this ](const CallNode* call_node) {
41
+ auto args = PrepareArgs (call_node);
42
+ if (const ConstantNode* shape = args[1 ].as <ConstantNode>()) {
42
43
ICHECK_EQ (shape->data ->ndim , 1 );
43
44
return MakeReshape (call_node->args [0 ], ToVector (shape->data ));
44
45
}
45
46
return Expr (nullptr );
46
47
}},
47
48
{Op::Get (" dyn.tile" ),
48
- [](const CallNode* call_node) {
49
- if (const ConstantNode* reps = call_node->args [1 ].as <ConstantNode>()) {
49
+ [this ](const CallNode* call_node) {
50
+ auto args = PrepareArgs (call_node);
51
+ if (const ConstantNode* reps = args[1 ].as <ConstantNode>()) {
50
52
ICHECK_EQ (reps->data ->ndim , 1 );
51
53
return MakeTile (call_node->args [0 ], ToVector (reps->data ));
52
54
}
53
55
return Expr (nullptr );
54
56
}},
55
57
{Op::Get (" dyn.topk" ),
56
- [](const CallNode* call_node) {
57
- if (const ConstantNode* k = call_node->args [1 ].as <ConstantNode>()) {
58
+ [this ](const CallNode* call_node) {
59
+ auto args = PrepareArgs (call_node);
60
+ if (const ConstantNode* k = args[1 ].as <ConstantNode>()) {
58
61
const TopKAttrs* param = call_node->attrs .as <TopKAttrs>();
59
62
ICHECK (param);
60
63
return MakeTopK (call_node->args [0 ], static_cast <int >(ToScalar (k->data , 0 )),
@@ -63,34 +66,38 @@ class DynamicToStaticMutator : public MixedModeMutator {
63
66
return Expr (nullptr );
64
67
}},
65
68
{Op::Get (" dyn.broadcast_to" ),
66
- [](const CallNode* call_node) {
67
- if (const ConstantNode* shape = call_node->args [1 ].as <ConstantNode>()) {
69
+ [this ](const CallNode* call_node) {
70
+ auto args = PrepareArgs (call_node);
71
+ if (const ConstantNode* shape = args[1 ].as <ConstantNode>()) {
68
72
ICHECK_EQ (shape->data ->ndim , 1 );
69
73
return MakeBroadCastTo (call_node->args [0 ], ToVector (shape->data ));
70
74
}
71
75
return Expr (nullptr );
72
76
}},
73
77
{Op::Get (" dyn.zeros" ),
74
- [](const CallNode* call_node) {
75
- if (const ConstantNode* shape = call_node->args [0 ].as <ConstantNode>()) {
78
+ [this ](const CallNode* call_node) {
79
+ auto args = PrepareArgs (call_node);
80
+ if (const ConstantNode* shape = args[0 ].as <ConstantNode>()) {
76
81
const InitOpAttrs* param = call_node->attrs .as <InitOpAttrs>();
77
82
ICHECK (param);
78
83
return MakeZeros (ToVector (shape->data ), param->dtype );
79
84
}
80
85
return Expr (nullptr );
81
86
}},
82
87
{Op::Get (" dyn.ones" ),
83
- [](const CallNode* call_node) {
84
- if (const ConstantNode* shape = call_node->args [0 ].as <ConstantNode>()) {
88
+ [this ](const CallNode* call_node) {
89
+ auto args = PrepareArgs (call_node);
90
+ if (const ConstantNode* shape = args[0 ].as <ConstantNode>()) {
85
91
const InitOpAttrs* param = call_node->attrs .as <InitOpAttrs>();
86
92
ICHECK (param);
87
93
return MakeOnes (ToVector (shape->data ), param->dtype );
88
94
}
89
95
return Expr (nullptr );
90
96
}},
91
97
{Op::Get (" dyn.one_hot" ),
92
- [](const CallNode* call_node) {
93
- if (const ConstantNode* depth = call_node->args [3 ].as <ConstantNode>()) {
98
+ [this ](const CallNode* call_node) {
99
+ auto args = PrepareArgs (call_node);
100
+ if (const ConstantNode* depth = args[3 ].as <ConstantNode>()) {
94
101
const OneHotAttrs* param = call_node->attrs .as <OneHotAttrs>();
95
102
ICHECK (param);
96
103
return MakeOneHot (call_node->args [0 ], call_node->args [1 ], call_node->args [2 ],
@@ -100,8 +107,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
100
107
return Expr (nullptr );
101
108
}},
102
109
{Op::Get (" dyn.image.resize" ),
103
- [](const CallNode* call_node) {
104
- if (const ConstantNode* size = call_node->args [1 ].as <ConstantNode>()) {
110
+ [this ](const CallNode* call_node) {
111
+ auto args = PrepareArgs (call_node);
112
+ if (const ConstantNode* size = args[1 ].as <ConstantNode>()) {
105
113
const ResizeAttrs* param = call_node->attrs .as <ResizeAttrs>();
106
114
ICHECK (param);
107
115
auto size_int = ToVector (size->data );
@@ -115,8 +123,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
115
123
return Expr (nullptr );
116
124
}},
117
125
{Op::Get (" dyn.full" ),
118
- [](const CallNode* call_node) {
119
- if (const ConstantNode* shape = call_node->args [1 ].as <ConstantNode>()) {
126
+ [this ](const CallNode* call_node) {
127
+ auto args = PrepareArgs (call_node);
128
+ if (const ConstantNode* shape = args[1 ].as <ConstantNode>()) {
120
129
ICHECK_EQ (shape->data ->ndim , 1 );
121
130
const InitOpAttrs* param = call_node->attrs .as <InitOpAttrs>();
122
131
ICHECK (param);
@@ -125,9 +134,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
125
134
return Expr (nullptr );
126
135
}},
127
136
{Op::Get (" dyn.nn.upsampling" ),
128
- [](const CallNode* call_node) {
129
- const ConstantNode* scale_h = call_node->args [1 ].as <ConstantNode>();
130
- const ConstantNode* scale_w = call_node->args [2 ].as <ConstantNode>();
137
+ [this ](const CallNode* call_node) {
138
+ auto args = PrepareArgs (call_node);
139
+ const ConstantNode* scale_h = args[1 ].as <ConstantNode>();
140
+ const ConstantNode* scale_w = args[2 ].as <ConstantNode>();
131
141
if (scale_h && scale_w) {
132
142
ICHECK_EQ (scale_h->data ->ndim , 0 );
133
143
ICHECK_EQ (scale_w->data ->ndim , 0 );
@@ -140,10 +150,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
140
150
return Expr (nullptr );
141
151
}},
142
152
{Op::Get (" dyn.nn.upsampling3d" ),
143
- [](const CallNode* call_node) {
144
- const ConstantNode* scale_d = call_node->args [1 ].as <ConstantNode>();
145
- const ConstantNode* scale_h = call_node->args [2 ].as <ConstantNode>();
146
- const ConstantNode* scale_w = call_node->args [3 ].as <ConstantNode>();
153
+ [this ](const CallNode* call_node) {
154
+ auto args = PrepareArgs (call_node);
155
+ const ConstantNode* scale_d = args[1 ].as <ConstantNode>();
156
+ const ConstantNode* scale_h = args[2 ].as <ConstantNode>();
157
+ const ConstantNode* scale_w = args[3 ].as <ConstantNode>();
147
158
if (scale_d && scale_h && scale_w) {
148
159
ICHECK_EQ (scale_d->data ->ndim , 0 );
149
160
ICHECK_EQ (scale_h->data ->ndim , 0 );
@@ -159,9 +170,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
159
170
return Expr (nullptr );
160
171
}},
161
172
{Op::Get (" dyn.nn.pad" ),
162
- [](const CallNode* call_node) {
163
- const ConstantNode* pad_width = call_node->args [1 ].as <ConstantNode>();
164
- const ConstantNode* pad_fill = call_node->args [2 ].as <ConstantNode>();
173
+ [this ](const CallNode* call_node) {
174
+ auto args = PrepareArgs (call_node);
175
+ const ConstantNode* pad_width = args[1 ].as <ConstantNode>();
176
+ const ConstantNode* pad_fill = args[2 ].as <ConstantNode>();
165
177
if (pad_width && pad_fill) {
166
178
ICHECK_EQ (pad_fill->data ->ndim , 0 ); // pad_val is 1d
167
179
ICHECK_EQ (pad_width->data ->ndim , 2 ); // pad_width is 2d
@@ -174,10 +186,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
174
186
return Expr (nullptr );
175
187
}},
176
188
{Op::Get (" dyn.strided_slice" ),
177
- [](const CallNode* call_node) {
178
- const ConstantNode* begin = call_node->args [1 ].as <ConstantNode>();
179
- const ConstantNode* end = call_node->args [2 ].as <ConstantNode>();
180
- const ConstantNode* stride = call_node->args [3 ].as <ConstantNode>();
189
+ [this ](const CallNode* call_node) {
190
+ auto args = PrepareArgs (call_node);
191
+ const ConstantNode* begin = args[1 ].as <ConstantNode>();
192
+ const ConstantNode* end = args[2 ].as <ConstantNode>();
193
+ const ConstantNode* stride = args[3 ].as <ConstantNode>();
181
194
if (begin && end && stride) {
182
195
ICHECK_EQ (begin->data ->ndim , 1 );
183
196
ICHECK_EQ (end->data ->ndim , 1 );
@@ -190,8 +203,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
190
203
return Expr (nullptr );
191
204
}},
192
205
{Op::Get (" dyn.sparse_to_dense" ),
193
- [](const CallNode* call_node) {
194
- const ConstantNode* output_shape = call_node->args [3 ].as <ConstantNode>();
206
+ [this ](const CallNode* call_node) {
207
+ auto args = PrepareArgs (call_node);
208
+ const ConstantNode* output_shape = args[3 ].as <ConstantNode>();
195
209
if (output_shape) {
196
210
ICHECK_EQ (output_shape->data ->ndim , 1 );
197
211
return MakeSparseToDense (call_node->args [0 ], ToVector (output_shape->data ),
@@ -200,6 +214,45 @@ class DynamicToStaticMutator : public MixedModeMutator {
200
214
return Expr (nullptr );
201
215
}},
202
216
};
217
+ Map<BaseFunc, GlobalVar> vars;
218
+ for (auto kv : mod_->functions ) {
219
+ vars.Set (kv.second , kv.first );
220
+ }
221
+ gv_ = vars[func_];
222
+ }
223
+
224
+ Expr PrepareInput (const Expr& expr) {
225
+ BaseFunc func;
226
+ if (auto * func_node = expr.as <BaseFuncNode>()) {
227
+ func = GetRef<BaseFunc>(func_node);
228
+ } else {
229
+ func =
230
+ relay::Function (relay::FreeVars (expr), expr, Type (), relay::FreeTypeVars (expr, mod_), {});
231
+ }
232
+ mod_->Update (gv_, func);
233
+ mod_ = transform::FoldConstant ()(mod_);
234
+ mod_ = transform::InferType ()(mod_);
235
+ mod_ = transform::FoldConstant ()(mod_);
236
+ mod_ = transform::InferType ()(mod_);
237
+ Expr out;
238
+ if (expr.as <FunctionNode>()) {
239
+ out = mod_->Lookup (gv_);
240
+ } else {
241
+ out = mod_->Lookup (gv_).as <FunctionNode>()->body ;
242
+ }
243
+ return out;
244
+ }
245
+
246
+ std::vector<Expr> PrepareArgs (const CallNode* call_node) {
247
+ std::vector<Expr> args;
248
+ for (auto arg : call_node->args ) {
249
+ if (arg.as <ConstantNode>()) {
250
+ args.emplace_back (arg);
251
+ } else {
252
+ args.emplace_back (PrepareInput (arg));
253
+ }
254
+ }
255
+ return args;
203
256
}
204
257
205
258
private:
@@ -222,35 +275,19 @@ class DynamicToStaticMutator : public MixedModeMutator {
222
275
}
223
276
return post ;
224
277
}
278
+
225
279
std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
226
280
op_map_;
281
+ IRModule mod_;
282
+ Function func_;
283
+ GlobalVar gv_;
227
284
};
228
285
229
286
Expr DynamicToStatic (Function f, IRModule m) {
230
- Expr pre = f;
231
- Expr expr = f;
232
- auto fold_const = transform::FoldConstant ();
233
- auto infer_type = transform::InferType ();
234
- DynamicToStaticMutator mutator;
235
- Map<BaseFunc, GlobalVar> vars;
236
- for (auto kv : m->functions ) {
237
- vars.Set (kv.second , kv.first );
238
- }
239
- const auto gv = vars[f];
240
- // Put a limit on the while loop
241
- // Primarily used to prevent accidental infinite lops in development
242
- const int loop_limit = 1000 ;
243
- int i = 0 ;
244
- do {
245
- pre = expr;
246
- // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
247
- m = infer_type (m);
248
- m = fold_const (m);
249
- expr = mutator.Mutate (m->functions [gv]);
250
- m->Update (gv, Downcast<BaseFunc>(expr));
251
- i += 1 ;
252
- } while (!StructuralEqual ()(pre , expr) && i < loop_limit);
253
- return expr;
287
+ DynamicToStaticMutator mutator (m, f);
288
+ Expr expr = mutator.Mutate (f);
289
+ Expr out = mutator.PrepareInput (expr);
290
+ return out;
254
291
}
255
292
256
293
namespace transform {
0 commit comments