Skip to content

Commit 3a33771

Browse files
authored
[TVMScript] Handle parsing of PrimFunc calls with non-void return (#15239)
* [TVMScript] Handle parsing of PrimFunc calls with non-void return Prior to this commit, the return type of all internal function calls was hard-coded as `"void"`. After this commit, the `GlobalVar` representing the internal function has type annotation based on the callee's signature, which is then used as the return type of the internal call. * Update CallNode return type in MakeUnpackedAPI
1 parent 81463d7 commit 3a33771

File tree

4 files changed

+45
-6
lines changed

4 files changed

+45
-6
lines changed

python/tvm/tir/op.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
445445
The call expression.
446446
"""
447447
assert isinstance(global_var, tvm.ir.GlobalVar)
448-
return Call(dtype="void", op=global_var, args=args)
448+
449+
dtype = "void"
450+
if global_var.checked_type is not None:
451+
ret_type = global_var.checked_type.ret_type
452+
if hasattr(ret_type, "dtype"):
453+
dtype = ret_type.dtype
454+
455+
return Call(dtype=dtype, op=global_var, args=args)
449456

450457

451458
def start_profile_intrinsic(id):

src/script/ir_builder/ir/ir.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <tvm/ir/module.h>
2020
#include <tvm/runtime/registry.h>
2121
#include <tvm/script/ir_builder/ir/ir.h>
22+
#include <tvm/tir/function.h>
23+
#include <tvm/tir/op.h>
2224

2325
#include "./utils.h"
2426

@@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature)
3840
IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
3941
CHECK(!frame->global_var_map.count(func_name))
4042
<< "ValueError: function " << func_name << " already exists";
41-
GlobalVar gv = GlobalVar(func_name);
43+
44+
auto gvar_type = [&]() -> Type {
45+
if (auto prim_func = func_signature.as<tir::PrimFuncNode>()) {
46+
Array<Type> arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); });
47+
return FuncType(arg_types, prim_func->ret_type, {}, {});
48+
}
49+
50+
return {};
51+
}();
52+
53+
GlobalVar gv = GlobalVar(func_name, gvar_type);
4254
CHECK(frame->functions.find(gv) == frame->functions.end())
4355
<< "ValueError: function " << func_name << " has already been defined.";
4456
frame->global_var_map.Set(func_name, gv);

src/tir/transforms/make_unpacked_api.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator {
6464

6565
if (auto gvar = node->op.as<GlobalVarNode>()) {
6666
if (external_methods_.count(gvar)) {
67-
Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
67+
Array<PrimExpr> args = node->args.Map([](const PrimExpr& arg) -> PrimExpr {
6868
if (auto* as_call = arg.as<CallNode>()) {
6969
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
7070
PrimExpr data_ptr = as_call->args[0];
71-
made_change_ = true;
7271
return data_ptr;
7372
}
7473
}
7574
return arg;
7675
});
77-
if (!args.same_as(node->args)) {
78-
node.CopyOnWrite()->args = args;
76+
77+
if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) {
78+
auto write_ptr = node.CopyOnWrite();
79+
write_ptr->dtype = DataType::Int(32);
80+
write_ptr->args = args;
81+
made_change_ = true;
7982
}
8083
}
8184
}

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3817,6 +3817,22 @@ def subroutine(A_data: T.handle("float32"), n: T.int32):
38173817
return mod
38183818

38193819

3820+
def subroutine_call_returning_int():
3821+
"""An internal function call may return non-void"""
3822+
3823+
@I.ir_module
3824+
class mod:
3825+
@T.prim_func
3826+
def main(A: T.Buffer(2, "float32")):
3827+
mod.subroutine(A[0]) + mod.subroutine(A[1])
3828+
3829+
@T.prim_func
3830+
def subroutine(x: T.float32) -> T.float32:
3831+
T.ret(x * x)
3832+
3833+
return mod
3834+
3835+
38203836
def undefined_data_ptr_in_decl_buffer():
38213837
"""The T.decl_buffer syntax should not introduce an Allocate
38223838
@@ -4009,6 +4025,7 @@ def func():
40094025
ir_module_with_attrs,
40104026
nested_seqstmt,
40114027
subroutine_call,
4028+
subroutine_call_returning_int,
40124029
undefined_data_ptr_in_decl_buffer,
40134030
undefined_shape_in_decl_buffer,
40144031
undefined_stride_in_decl_buffer,

0 commit comments

Comments
 (0)