Skip to content

Commit

Permalink
[TensorIR] Print TVMScript with prefix T instead of tir (apache#9422)
Browse files Browse the repository at this point in the history
  • Loading branch information
quic-sanirudh authored and ylc committed Jan 7, 2022
1 parent 7c51240 commit 92caf33
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __str__(self):
def __repr__(self):
return self.astext()

def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Parameters
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
"""
return _ffi_api.Specialize(self, param_map) # type: ignore

def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Parameters
Expand Down
21 changes: 19 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,17 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
}
return doc;
}

public:
static Doc PrintHeader(const std::string& tir_prefix) {
Doc header;
if (tir_prefix != "tir") {
header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine();
} else {
header << "# from tvm.script import tir" << Doc::NewLine();
}
return header;
}
};

Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
Expand Down Expand Up @@ -1431,15 +1442,21 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {

String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n";
Doc doc;
doc << TVMScriptPrinter::PrintHeader(tir_prefix)
<< TVMScriptPrinter(tir_prefix, show_meta).Print(mod);
return doc.str() + "\n";
}

TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);

String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n";
Doc doc;
doc << TVMScriptPrinter::PrintHeader(tir_prefix)
<< TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod);
return doc.str() + "\n";
}

TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ String ScheduleError::RenderReport(const String& primitive) const {

os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
<< "'.\n\nThe IR with diagnostic is:\n"
<< AsTVMScriptWithDiagnostic(mod, "tir", false, annotate);
<< AsTVMScriptWithDiagnostic(mod, "T", false, annotate);

// print error message
os << "Error message: " << msg;
Expand Down
16 changes: 8 additions & 8 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ def test_reorder_fail_block():
sch.reorder(l, i)
expected_sub_error_message = (
" # tir.Block#0\n"
' with tir.block("B"):\n'
" ^^^^^^^^^^^^^^^^^^^^\n"
' with T.block("B"):\n'
" ^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand All @@ -561,10 +561,10 @@ def test_reorder_fail_nested_loop_inner():
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(k, i)
expected_sub_error_message = (
" for i in tir.serial(0, 128):\n"
" for i in T.serial(0, 128):\n"
" # tir.For#0\n"
" for j in tir.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand All @@ -577,9 +577,9 @@ def test_fuse_fail_nested_loop_outer():
sch.fuse(k, i)
expected_sub_error_message = (
" # tir.For#1\n"
" for i in tir.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in tir.serial(0, 128):\n"
" for i in T.serial(0, 128):\n"
" ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
" for j in T.serial(0, 128):\n"
)
assert expected_sub_error_message in str(execinfo.value)

Expand Down

0 comments on commit 92caf33

Please sign in to comment.