|
15 | 15 | # specific language governing permissions and limitations
|
16 | 16 | # under the License.
|
17 | 17 | """Utility for converting Relay code into a Python script with equivalent semantics"""
|
| 18 | +import sys |
18 | 19 | import ast
|
19 | 20 | from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
|
20 | 21 | import re
|
|
27 | 28 | from tvm.relay.function import Function
|
28 | 29 | from tvm.relay.expr_functor import ExprFunctor
|
29 | 30 |
|
| 31 | +__MAJOR__, __MINOR__, _, _, _ = sys.version_info |
| 32 | + |
30 | 33 | OUTPUT_VAR_NAME = "_py_out"
|
31 | 34 |
|
32 | 35 | # corresponds to:
|
@@ -82,8 +85,12 @@ def convert(self, prog: Expr):
|
82 | 85 | # we finally must assign the final expression to the output var
|
83 | 86 | # so it can be read after running EXEC
|
84 | 87 | body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))
|
| 88 | + global __MAJOR__, __MINOR__ |
85 | 89 |
|
86 |
| - return ast.fix_missing_locations(ast.Module(body=body)) |
| 90 | + if __MAJOR__ == 3 and __MINOR__ == 8: |
| 91 | + return ast.fix_missing_locations(ast.Module(body=body,type_ignores=[])) |
| 92 | + else: |
| 93 | + return ast.fix_missing_locations(ast.Module(body=body)) |
87 | 94 |
|
88 | 95 | def optimize(self, prog: Expr):
|
89 | 96 | """Performs optimizations necessary to be able to generate code for prog."""
|
@@ -210,11 +217,19 @@ def create_call(self, func_name: str, arguments):
|
210 | 217 |
|
211 | 218 | def create_def(self, func_name: str, arguments: [str], body):
|
212 | 219 | """Wrapper over function definition AST node, whose constructor is inconvenient."""
|
| 220 | + inner_args = [ast.arg(argument, None) for argument in arguments] |
| 221 | + |
| 222 | + global __MAJOR__, __MINOR__ |
| 223 | + if __MAJOR__ == 3 and __MINOR__ == 8: |
| 224 | + arguments = ast.arguments( |
| 225 | + [], inner_args, None, [], [], None, []) |
| 226 | + else: |
| 227 | + arguments = ast.arguments( |
| 228 | + inner_args, None, [], [], None, []) |
| 229 | + |
213 | 230 | return ast.FunctionDef(
|
214 | 231 | func_name,
|
215 |
| - ast.arguments( |
216 |
| - [ast.arg(argument, None) for argument in arguments], None, [], [], None, [] |
217 |
| - ), |
| 232 | + arguments, |
218 | 233 | body,
|
219 | 234 | [],
|
220 | 235 | None,
|
@@ -576,8 +591,11 @@ def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
|
576 | 591 | """Converts the given Relay expression into a Python script (as a Python AST object).
|
577 | 592 | For easiest debugging, import the astor package and use to_source()."""
|
578 | 593 | mod = mod if mod is not None else tvm.IRModule()
|
| 594 | + mod = relay.transform.InferType()(mod) |
579 | 595 | converter = PythonConverter(mod, target)
|
580 |
| - return converter.convert(expr) |
| 596 | + python = converter.convert(expr) |
| 597 | + assert python |
| 598 | + return python |
581 | 599 |
|
582 | 600 |
|
583 | 601 | def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
|
|
0 commit comments