Skip to content

Commit b3331bc

Browse files
committed
Repair even more tests
1 parent 824c907 commit b3331bc

17 files changed

+112
-51
lines changed

python/tvm/relay/quantize/_partition_conversions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def fuse_partitions(pre_mod, mid_mod, post_mod):
121121
relay.GlobalVar("dequantize_outputs"): post_func,
122122
}
123123
)
124+
124125
# construct a `main` that strings together the partitions, such that its
125126
# behaviour is equivalent to `main` in an *unpartitioned* module
126127
scope_builder = relay.ScopeBuilder()
@@ -142,7 +143,7 @@ def fuse_partitions(pre_mod, mid_mod, post_mod):
142143
)
143144
scope_builder.ret(dequantized_outputs)
144145
fused_mod["main"] = relay.Function(fused_mod_main_params, scope_builder.get())
145-
return fused_mod
146+
return relay.transform.InferType()(fused_mod)
146147

147148

148149
class PrefixCutter(ExprMutator):
@@ -217,6 +218,7 @@ def partition_prefix(mod, quantized_dtypes):
217218
assert func.attrs is None, "unimplemented"
218219
mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body)
219220
mid_mod = tvm.IRModule.from_expr(mid_func)
221+
mid_mod = relay.transform.InferType()(mid_mod)
220222

221223
scope_builder = prefix_cutter.prefix_sb
222224
# make sure we pass through all inputs in the prefix function's return expr
@@ -237,6 +239,7 @@ def partition_prefix(mod, quantized_dtypes):
237239
pre_func_body = scope_builder.get()
238240
pre_func = relay.Function(relay.analysis.free_vars(pre_func_body), pre_func_body)
239241
pre_mod = tvm.IRModule.from_expr(pre_func)
242+
pre_mod = relay.transform.InferType()(pre_mod)
240243

241244
return pre_mod, mid_mod
242245

@@ -288,6 +291,7 @@ def partition_suffix(mod, quantized_dtypes):
288291
assert func.attrs is None, "unimplemented"
289292
post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type)
290293
post_mod = tvm.IRModule.from_expr(post_func)
294+
post_mod = relay.transform.InferType()(post_mod)
291295

292296
mid_body = suffix_cutter.mid_body
293297
if mid_body is None:
@@ -298,9 +302,11 @@ def partition_suffix(mod, quantized_dtypes):
298302
post_body = relay.Var("input", mid_mod["main"].ret_type)
299303
post_func = relay.Function([post_body], post_body)
300304
post_mod = tvm.IRModule.from_expr(post_func)
305+
post_mod = relay.transform.InferType()(post_mod)
301306
else:
302307
mid_func = relay.Function(func.params, mid_body)
303308
mid_mod = tvm.IRModule.from_expr(mid_func)
309+
mid_mod = relay.transform.InferType()(mid_mod)
304310

305311
return mid_mod, post_mod
306312

python/tvm/relay/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def run_opt_pass(expr, opt_pass, import_prelude=False):
5353
mod = tvm.IRModule.from_expr(expr)
5454
if import_prelude:
5555
Prelude(mod)
56+
mod = relay.transform.InferType()(mod)
5657
mod = opt_pass(mod)
5758
entry = mod["main"]
5859
return entry if isinstance(expr, relay.Function) else entry.body

python/tvm/relay/testing/py_converter.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Utility for converting Relay code into a Python script with equivalent semantics"""
18+
import sys
1819
import ast
1920
from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
2021
import re
@@ -27,6 +28,8 @@
2728
from tvm.relay.function import Function
2829
from tvm.relay.expr_functor import ExprFunctor
2930

31+
__MAJOR__, __MINOR__, _, _, _ = sys.version_info
32+
3033
OUTPUT_VAR_NAME = "_py_out"
3134

3235
# corresponds to:
@@ -82,8 +85,12 @@ def convert(self, prog: Expr):
8285
# we finally must assign the final expression to the output var
8386
# so it can be read after running EXEC
8487
body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))
88+
global __MAJOR__, __MINOR__
8589

86-
return ast.fix_missing_locations(ast.Module(body=body))
90+
if __MAJOR__ == 3 and __MINOR__ == 8:
91+
return ast.fix_missing_locations(ast.Module(body=body,type_ignores=[]))
92+
else:
93+
return ast.fix_missing_locations(ast.Module(body=body))
8794

8895
def optimize(self, prog: Expr):
8996
"""Performs optimizations necessary to be able to generate code for prog."""
@@ -210,11 +217,19 @@ def create_call(self, func_name: str, arguments):
210217

211218
def create_def(self, func_name: str, arguments: [str], body):
212219
"""Wrapper over function definition AST node, whose constructor is inconvenient."""
220+
inner_args = [ast.arg(argument, None) for argument in arguments]
221+
222+
global __MAJOR__, __MINOR__
223+
if __MAJOR__ == 3 and __MINOR__ == 8:
224+
arguments = ast.arguments(
225+
[], inner_args, None, [], [], None, [])
226+
else:
227+
arguments = ast.arguments(
228+
inner_args, None, [], [], None, [])
229+
213230
return ast.FunctionDef(
214231
func_name,
215-
ast.arguments(
216-
[ast.arg(argument, None) for argument in arguments], None, [], [], None, []
217-
),
232+
arguments,
218233
body,
219234
[],
220235
None,
@@ -576,8 +591,11 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
576591
"""Converts the given Relay expression into a Python script (as a Python AST object).
577592
For easiest debugging, import the astor package and use to_source()."""
578593
mod = mod if mod is not None else tvm.IRModule()
594+
mod = relay.transform.InferType()(mod)
579595
converter = PythonConverter(mod, target)
580-
return converter.convert(expr)
596+
python = converter.convert(expr)
597+
assert python
598+
return python
581599

582600

583601
def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):

tests/python/relay/test_ir_well_formed.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ def test_tuple_get_item():
5252
def test_adt():
5353
mod = tvm.IRModule()
5454
p = Prelude(mod)
55+
_, none, some = p.mod.get_type("Option")
5556
x = relay.Var("x")
56-
some_case = relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]), x)
57+
some_case = relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(x)]), x)
5758
default_case = relay.Clause(relay.PatternVar(x), x)
58-
m0 = relay.Match(p.none(), [default_case])
59-
m1 = relay.Match(p.none(), [some_case, default_case])
59+
m0 = relay.Match(none(), [default_case])
60+
m1 = relay.Match(none(), [some_case, default_case])
6061
assert well_formed(m0)
6162
assert not well_formed(m1)
6263

tests/python/relay/test_op_qnn_add.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_tflite_same_io_qnn_params():
3838

3939
func = relay.Function([x, y], z)
4040
mod = tvm.IRModule.from_expr(func)
41+
mod = relay.transform.InferType()(mod)
4142
mod = relay.qnn.transform.CanonicalizeOps()(mod)
4243
func = mod["main"]
4344

@@ -85,6 +86,7 @@ def test_tflite_different_io_qnn_params():
8586

8687
func = relay.Function([x, y], z)
8788
mod = tvm.IRModule.from_expr(func)
89+
mod = relay.transform.InferType()(mod)
8890
mod = relay.qnn.transform.CanonicalizeOps()(mod)
8991
func = mod["main"]
9092

@@ -132,8 +134,10 @@ def test_saturation():
132134

133135
func = relay.Function([x, y], z)
134136
mod = tvm.IRModule.from_expr(func)
137+
mod = relay.transform.InferType()(mod)
135138
mod = relay.qnn.transform.CanonicalizeOps()(mod)
136139
func = mod["main"]
140+
mod = relay.transform.InferType()(mod)
137141

138142
x_data = np.array((255, 1, 1, 0)).reshape((1, 4))
139143
y_data = np.array((255, 255, 128, 0)).reshape((1, 4))
@@ -157,6 +161,7 @@ def test_saturation():
157161

158162
func = relay.Function([x, y], z)
159163
mod = tvm.IRModule.from_expr(func)
164+
mod = relay.transform.InferType()(mod)
160165
mod = relay.qnn.transform.CanonicalizeOps()(mod)
161166
func = mod["main"]
162167

@@ -182,6 +187,7 @@ def test_saturation():
182187

183188
func = relay.Function([x, y], z)
184189
mod = tvm.IRModule.from_expr(func)
190+
mod = relay.transform.InferType()(mod)
185191
mod = relay.qnn.transform.CanonicalizeOps()(mod)
186192
func = mod["main"]
187193

@@ -207,6 +213,7 @@ def test_saturation():
207213

208214
func = relay.Function([x, y], z)
209215
mod = tvm.IRModule.from_expr(func)
216+
mod = relay.transform.InferType()(mod)
210217
mod = relay.qnn.transform.CanonicalizeOps()(mod)
211218
func = mod["main"]
212219

tests/python/relay/test_op_qnn_concatenate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_same_io_qnn_params():
4545

4646
func = relay.Function([x, y], z)
4747
mod = tvm.IRModule.from_expr(func)
48+
mod = relay.transform.InferType()(mod)
4849
mod = relay.qnn.transform.CanonicalizeOps()(mod)
4950
func = mod["main"]
5051

@@ -79,6 +80,7 @@ def test_different_io_qnn_params():
7980

8081
func = relay.Function([x, y], z)
8182
mod = tvm.IRModule.from_expr(func)
83+
mod = relay.transform.InferType()(mod)
8284
mod = relay.qnn.transform.CanonicalizeOps()(mod)
8385
func = mod["main"]
8486

@@ -113,6 +115,7 @@ def test_few_same_io_qnn_params():
113115

114116
func = relay.Function([x, y], z)
115117
mod = tvm.IRModule.from_expr(func)
118+
mod = relay.transform.InferType()(mod)
116119
mod = relay.qnn.transform.CanonicalizeOps()(mod)
117120
func = mod["main"]
118121

@@ -147,6 +150,7 @@ def test_same_i_qnn_params():
147150

148151
func = relay.Function([x, y], z)
149152
mod = tvm.IRModule.from_expr(func)
153+
mod = relay.transform.InferType()(mod)
150154
mod = relay.qnn.transform.CanonicalizeOps()(mod)
151155
func = mod["main"]
152156

tests/python/relay/test_op_qnn_dense.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def qnn_dense_driver(test_configuration):
207207

208208
mod = relay.Function(relay.analysis.free_vars(mod), mod)
209209
mod = tvm.IRModule.from_expr(mod)
210+
mod = relay.transform.InferType()(mod)
210211
mod = relay.qnn.transform.CanonicalizeOps()(mod)
211212
with tvm.transform.PassContext(opt_level=2):
212213
graph, lib, params = relay.build(mod, "llvm", params=None)

tests/python/relay/test_op_qnn_mul.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_tflite_same_io_qnn_params():
5757

5858
func = relay.Function([x, y], z)
5959
mod = tvm.IRModule.from_expr(func)
60+
mod = relay.transform.InferType()(mod)
6061
mod = relay.qnn.transform.CanonicalizeOps()(mod)
6162
func = mod["main"]
6263

@@ -110,6 +111,7 @@ def test_tflite_different_io_qnn_params():
110111

111112
func = relay.Function([x, y], z)
112113
mod = tvm.IRModule.from_expr(func)
114+
mod = relay.transform.InferType()(mod)
113115
mod = relay.qnn.transform.CanonicalizeOps()(mod)
114116
func = mod["main"]
115117

@@ -158,6 +160,7 @@ def test_saturation():
158160

159161
func = relay.Function([x, y], z)
160162
mod = tvm.IRModule.from_expr(func)
163+
mod = relay.transform.InferType()(mod)
161164
mod = relay.qnn.transform.CanonicalizeOps()(mod)
162165
func = mod["main"]
163166

@@ -191,6 +194,7 @@ def test_saturation():
191194

192195
func = relay.Function([x, y], z)
193196
mod = tvm.IRModule.from_expr(func)
197+
mod = relay.transform.InferType()(mod)
194198
mod = relay.qnn.transform.CanonicalizeOps()(mod)
195199
func = mod["main"]
196200

@@ -225,6 +229,7 @@ def test_saturation():
225229

226230
func = relay.Function([x, y], z)
227231
mod = tvm.IRModule.from_expr(func)
232+
mod = relay.transform.InferType()(mod)
228233
mod = relay.qnn.transform.CanonicalizeOps()(mod)
229234
func = mod["main"]
230235

tests/python/relay/test_op_qnn_subtract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp, data_dty
4545
)
4646
func = relay.Function([x, y], z)
4747
mod = tvm.IRModule.from_expr(func)
48+
mod = relay.transform.InferType()(mod)
4849
mod = relay.qnn.transform.CanonicalizeOps()(mod)
4950
func = mod["main"]
5051
for i in range(0, len(x_datas)):

tests/python/relay/test_pass_annotate_target.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def expected(dtype, ishape, w1shape):
127127
def test_annotate():
128128
mod = annotated(dtype, ishape, w1shape)
129129
mod = transform.AnnotateTarget("dnnl")(mod)
130+
mod = relay.transform.InferType()(mod)
130131
ref_mod = expected(dtype, ishape, w1shape)
132+
ref_mod = relay.transform.InferType()(ref_mod)
131133
tvm.ir.assert_structural_equal(mod, ref_mod)
132134

133135
def test_run():

0 commit comments

Comments
 (0)