Commit 3635945

Author: Matthew Brookhart
Refactor Dynamic to Static (apache#7368)

* DynamicToStatic Refactor
* fix test
* add regression tests
* cleanup
* skip PrepareInput if the arg is already a constant
* fix an issue with type inference with global functions
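
For context: the DynamicToStatic pass rewrites Relay's dyn.* operators to their static counterparts whenever their shape-like arguments can be reduced to constants. A minimal sketch of invoking the pass, modeled on the tests below (illustrative, not part of this commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 4), dtype="float32")
y = relay.var("y", shape=(4, 2), dtype="float32")
# newshape is a Relay expression, so relay.reshape lowers to dyn.reshape.
z = relay.reshape(x, relay.shape_of(y))
mod = tvm.IRModule.from_expr(relay.Function([x, y], z))

mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
print(mod)  # shape_of(y) folds to [4, 2]; the body now calls static reshape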
1 parent 0bd259a commit 3635945

2 files changed: +138 −61 lines

src/relay/transforms/dynamic_to_static.cc
Lines changed: 96 additions & 59 deletions

@@ -34,27 +34,30 @@ namespace relay {
 
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() {
+  DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) {
     op_map_ = {
         {Op::Get("dyn.reshape"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
             ICHECK_EQ(shape->data->ndim, 1);
             return MakeReshape(call_node->args[0], ToVector(shape->data));
           }
           return Expr(nullptr);
         }},
        {Op::Get("dyn.tile"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
             ICHECK_EQ(reps->data->ndim, 1);
             return MakeTile(call_node->args[0], ToVector(reps->data));
           }
           return Expr(nullptr);
         }},
        {Op::Get("dyn.topk"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* k = args[1].as<ConstantNode>()) {
             const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
             ICHECK(param);
             return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
@@ -63,34 +66,38 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.broadcast_to"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
             ICHECK_EQ(shape->data->ndim, 1);
             return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
           }
           return Expr(nullptr);
         }},
        {Op::Get("dyn.zeros"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
             ICHECK(param);
             return MakeZeros(ToVector(shape->data), param->dtype);
           }
           return Expr(nullptr);
         }},
        {Op::Get("dyn.ones"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
             ICHECK(param);
             return MakeOnes(ToVector(shape->data), param->dtype);
           }
           return Expr(nullptr);
         }},
        {Op::Get("dyn.one_hot"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
             const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
             ICHECK(param);
             return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
@@ -100,8 +107,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.image.resize"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* size = args[1].as<ConstantNode>()) {
             const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
             ICHECK(param);
             auto size_int = ToVector(size->data);
@@ -115,8 +123,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.full"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
             ICHECK_EQ(shape->data->ndim, 1);
             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
             ICHECK(param);
@@ -125,9 +134,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.upsampling"),
-         [](const CallNode* call_node) {
-           const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* scale_h = args[1].as<ConstantNode>();
+           const ConstantNode* scale_w = args[2].as<ConstantNode>();
           if (scale_h && scale_w) {
             ICHECK_EQ(scale_h->data->ndim, 0);
             ICHECK_EQ(scale_w->data->ndim, 0);
@@ -140,10 +150,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.upsampling3d"),
-         [](const CallNode* call_node) {
-           const ConstantNode* scale_d = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* scale_h = call_node->args[2].as<ConstantNode>();
-           const ConstantNode* scale_w = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* scale_d = args[1].as<ConstantNode>();
+           const ConstantNode* scale_h = args[2].as<ConstantNode>();
+           const ConstantNode* scale_w = args[3].as<ConstantNode>();
           if (scale_d && scale_h && scale_w) {
             ICHECK_EQ(scale_d->data->ndim, 0);
             ICHECK_EQ(scale_h->data->ndim, 0);
@@ -159,9 +170,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.nn.pad"),
-         [](const CallNode* call_node) {
-           const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* pad_width = args[1].as<ConstantNode>();
+           const ConstantNode* pad_fill = args[2].as<ConstantNode>();
           if (pad_width && pad_fill) {
             ICHECK_EQ(pad_fill->data->ndim, 0);   // pad_val is 1d
             ICHECK_EQ(pad_width->data->ndim, 2);  // pad_width is 2d
@@ -174,10 +186,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.strided_slice"),
-         [](const CallNode* call_node) {
-           const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* end = call_node->args[2].as<ConstantNode>();
-           const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* begin = args[1].as<ConstantNode>();
+           const ConstantNode* end = args[2].as<ConstantNode>();
+           const ConstantNode* stride = args[3].as<ConstantNode>();
           if (begin && end && stride) {
             ICHECK_EQ(begin->data->ndim, 1);
             ICHECK_EQ(end->data->ndim, 1);
@@ -190,8 +203,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
        {Op::Get("dyn.sparse_to_dense"),
-         [](const CallNode* call_node) {
-           const ConstantNode* output_shape = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* output_shape = args[3].as<ConstantNode>();
           if (output_shape) {
             ICHECK_EQ(output_shape->data->ndim, 1);
             return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data),
@@ -200,6 +214,45 @@ class DynamicToStaticMutator : public MixedModeMutator {
           return Expr(nullptr);
         }},
     };
+    Map<BaseFunc, GlobalVar> vars;
+    for (auto kv : mod_->functions) {
+      vars.Set(kv.second, kv.first);
+    }
+    gv_ = vars[func_];
+  }
+
+  Expr PrepareInput(const Expr& expr) {
+    BaseFunc func;
+    if (auto* func_node = expr.as<BaseFuncNode>()) {
+      func = GetRef<BaseFunc>(func_node);
+    } else {
+      func =
+          relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
+    }
+    mod_->Update(gv_, func);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    Expr out;
+    if (expr.as<FunctionNode>()) {
+      out = mod_->Lookup(gv_);
+    } else {
+      out = mod_->Lookup(gv_).as<FunctionNode>()->body;
+    }
+    return out;
+  }
+
+  std::vector<Expr> PrepareArgs(const CallNode* call_node) {
+    std::vector<Expr> args;
+    for (auto arg : call_node->args) {
+      if (arg.as<ConstantNode>()) {
+        args.emplace_back(arg);
+      } else {
+        args.emplace_back(PrepareInput(arg));
+      }
+    }
+    return args;
   }
 
  private:
@@ -222,35 +275,19 @@ class DynamicToStaticMutator : public MixedModeMutator {
     }
     return post;
   }
+
   std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
       op_map_;
+  IRModule mod_;
+  Function func_;
+  GlobalVar gv_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
-  Expr pre = f;
-  Expr expr = f;
-  auto fold_const = transform::FoldConstant();
-  auto infer_type = transform::InferType();
-  DynamicToStaticMutator mutator;
-  Map<BaseFunc, GlobalVar> vars;
-  for (auto kv : m->functions) {
-    vars.Set(kv.second, kv.first);
-  }
-  const auto gv = vars[f];
-  // Put a limit on the while loop
-  // Primarily used to prevent accidental infinite lops in development
-  const int loop_limit = 1000;
-  int i = 0;
-  do {
-    pre = expr;
-    // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
-    m = infer_type(m);
-    m = fold_const(m);
-    expr = mutator.Mutate(m->functions[gv]);
-    m->Update(gv, Downcast<BaseFunc>(expr));
-    i += 1;
-  } while (!StructuralEqual()(pre, expr) && i < loop_limit);
-  return expr;
+  DynamicToStaticMutator mutator(m, f);
+  Expr expr = mutator.Mutate(f);
+  Expr out = mutator.PrepareInput(expr);
+  return out;
 }
 
 namespace transform {
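
Design note on the change above: the old DynamicToStatic driver re-ran InferType and FoldConstant over the whole module in a do/while loop until the function stopped changing. The refactor folds arguments on demand instead: PrepareArgs passes arguments that are already Constants through untouched (the "skip PrepareInput" bullet in the commit message) and routes everything else through PrepareInput, which installs the expression under the function's GlobalVar, runs FoldConstant and InferType, and reads the folded result back. A sketch of a case this enables, mirroring the new dynamic-rank regression test below (illustrative; assumes a TVM build containing this commit):

import tvm
from tvm import relay

x = relay.var("x", relay.scalar_type("float32"))
y = relay.var("y", relay.TensorType((1, 2, 8, 10), "int64"))
shape = relay.shape_of(y)
# The shape argument is strided_slice(shape_of(y), ...): not a Constant node,
# but foldable, so PrepareArgs reduces it and dyn.full becomes a static full.
shape = relay.strided_slice(shape, [0], relay.shape_of(shape))
z = relay.full(x, shape, "float32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], z))

mod = relay.transform.InferType()(mod)
mod = relay.transform.DynamicToStatic()(mod)
print(mod)  # full(...) with the static shape (1, 2, 8, 10)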

tests/python/relay/test_pass_dynamic_to_static.py
Lines changed: 42 additions & 2 deletions

@@ -232,11 +232,11 @@ def verify_ones_zeros(shape, dtype):
 
             func = run_infer_type(relay.Function([x], y))
             func2 = run_opt_pass(
-                run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()
+                run_opt_pass(func, transform.DynamicToStatic()),
+                transform.InferType(),
             )
 
             zz = func2.body
-            assert isinstance(zz, relay.Constant)
             assert zz.checked_type == relay.ty.TensorType(shape, dtype)
 
             x_data = np.random.uniform(low=1, high=1, size=shape)
@@ -518,5 +518,45 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
 
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_rank():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, "int64"))
+        shape = relay.shape_of(y)
+        shape = relay.strided_slice(shape, [0], relay.shape_of(shape))
+        z = relay.full(x, shape, dtype)
+
+        func = relay.Function([x, y], z)
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("full")
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype("int64")
+        verify_func(func2, [fill_value, y_data], ref_res)
+
+    verify_full(4, (1, 2, 3, 4), "int32")
+    verify_full(4.0, (1, 2, 8, 10), "float32")
+
+
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_if():
+    x = relay.var("x", relay.TensorType((2, 2), "int64"))
+    cond = relay.const(1)
+    iff = relay.If(cond, relay.reshape(x, [1, 4]), relay.reshape(x, (4, 1)))
+
+    func = relay.Function([x], iff)
+    func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+    zz = func2.body
+    assert isinstance(zz, relay.Call)
+    assert zz.op == relay.op.get("reshape")
+    x_data = np.random.uniform(low=-1, high=1, size=(2, 2)).astype("int64")
+    verify_func(func2, [x_data], x_data.reshape(1, 4))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
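
The two new regression tests can be run in isolation with something like the following, from a TVM checkout with the Python package importable (command is illustrative):

pytest tests/python/relay/test_pass_dynamic_to_static.py -k "dynamic_rank or dynamic_if"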
