Skip to content

Commit ad6cca0

Browse files
Lunderbergjunrushao
authored andcommitted
[TIR] Return error code from kernels in SplitHostDevice (apache#15241)
* [TVMScript] Handle parsing of PrimFunc calls with non-void return Prior to this commit, the return type of all internal function calls was hard-coded as `"void"`. After this commit, the `GlobalVar` representing the internal function has type annotation based on the callee's signature, which is then used as the return type of the internal call. * Update CallNode return type in MakeUnpackedAPI * [TIR] Return error code from kernels in SplitHostDevice Some codegen types delegate to `CodeGenCPU` for their compute kernels, as they may delegate work to packed functions. Because `CodeGenCPU` assumes that it can return an error code at any point (e.g. when launching a parallel for loop), the compute kernel should return an error code. * [TIR] Remove builtin::ret(0) from device-side kernel * Restrict the int32 return type to targets that need to propagate errors * Updated unit tests for CPU-specific checks
1 parent 9fd6f3f commit ad6cca0

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

src/tir/transforms/lower_device_kernel_launch.cc

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,36 @@ class DeviceInfoCollector : public StmtVisitor {
145145
// The amount of dynamic shared memory used
146146
Optional<PrimExpr> dyn_shmem_size{NullOpt};
147147
};
148+
149+
class ReturnRemover : public StmtExprMutator {
150+
public:
151+
static Stmt Apply(const Stmt& stmt) {
152+
ReturnRemover mutator;
153+
return mutator(stmt);
154+
}
155+
156+
private:
157+
using Parent = StmtExprMutator;
158+
Stmt VisitStmt_(const EvaluateNode* op) override {
159+
if (auto* call = op->value.as<CallNode>()) {
160+
if (call->op.same_as(builtin::ret())) {
161+
ICHECK_EQ(call->args.size(), 1);
162+
auto as_int = call->args[0].as<IntImmNode>();
163+
ICHECK(as_int && as_int->value == 0)
164+
<< "Device kernel may only contain successful return, T.ret(0)";
165+
return Evaluate(0);
166+
}
167+
}
168+
return Parent::VisitStmt_(op);
169+
}
170+
171+
PrimExpr VisitExpr_(const CallNode* op) override {
172+
if (op->op.same_as(builtin::ret())) {
173+
LOG(FATAL) << "Call to builtin::ret() should only appear within an Evaluate node";
174+
}
175+
return Parent::VisitExpr_(op);
176+
}
177+
};
148178
} // namespace
149179

150180
class DeviceKernelMutator : public StmtExprMutator {
@@ -185,10 +215,19 @@ class DeviceKernelMutator : public StmtExprMutator {
185215
if (is_kernel_launch) {
186216
const auto& info = device_info_map_.at(gvar.get());
187217

218+
// Kernel launches provide an int32 error code to the caller,
219+
// but do not accept any return type from the callee.
220+
{
221+
auto write_ptr = func.CopyOnWrite();
222+
write_ptr->ret_type = VoidType();
223+
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
224+
}
225+
188226
func = WithAttrs(std::move(func),
189227
{{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)},
190228
{tvm::tir::attr::kKernelLaunchParams, info.launch_params},
191229
{tvm::attr::kGlobalSymbol, info.global_symbol}});
230+
192231
} else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
193232
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
194233
}
@@ -197,7 +236,7 @@ class DeviceKernelMutator : public StmtExprMutator {
197236
}
198237

199238
private:
200-
PrimExpr VisitExpr_(const CallNode* op) {
239+
PrimExpr VisitExpr_(const CallNode* op) override {
201240
auto node = Downcast<Call>(Parent::VisitExpr_(op));
202241

203242
auto* gvar = op->op.as<GlobalVarNode>();

src/tir/transforms/split_host_device.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator {
6060
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
6161
use_def(body);
6262

63-
// Sort first by variable typ, then by variable name
63+
// Sort first by variable type, then by variable name
6464
std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
6565
std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
6666
auto sort_key = [](const Var& var) {
@@ -74,16 +74,43 @@ class HostDeviceSplitter : public StmtMutator {
7474
return params;
7575
}();
7676

77+
// CodeGenCPU is used for some device-side targets, such as
78+
// "ext_dev", and expects to be able to return a int32_t status
79+
// code.
80+
81+
bool can_propagate_errors = [&]() {
82+
auto kind = device_target->GetTargetDeviceType();
83+
return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
84+
}();
85+
IntImm success(DataType::Int(32), 0);
86+
Type kernel_ret_type;
87+
if (can_propagate_errors) {
88+
kernel_ret_type = PrimType(DataType::Int(32));
89+
body = SeqStmt::Flatten(body, Evaluate(ret(success)));
90+
} else {
91+
kernel_ret_type = VoidType();
92+
}
93+
7794
GlobalVar kernel_symbol_global = var_supply_();
78-
PrimFunc device_func(params, body);
95+
PrimFunc device_func(params, body, kernel_ret_type);
7996
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
8097
{tir::attr::kNoAlias, Bool(true)},
8198
{tir::attr::kIsGlobalFunc, Bool(true)}});
8299

83100
(*device_mod_)->Add(kernel_symbol_global, device_func);
84101
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
85102

86-
return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
103+
if (can_propagate_errors) {
104+
Var kernel_error_code("kernel_error_code", success->dtype);
105+
Call kernel_call(success->dtype, kernel_symbol_global, args);
106+
AssertStmt assert_success(kernel_error_code == success,
107+
StringImm("Error executing compute kernel"), Evaluate(0));
108+
LetStmt let_check(kernel_error_code, kernel_call, assert_success);
109+
110+
return std::move(let_check);
111+
} else {
112+
return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
113+
}
87114
}
88115

89116
// target ir module

tests/python/unittest/test_tir_transform_split_host_device.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,44 @@ def main_kernel(n: T.int32):
129129
return mod
130130

131131

132+
class TestSplitHostDeviceOnCPU(BaseCompare):
133+
"""A kernel running on the CPU may return an error code"""
134+
135+
def before(self):
136+
@I.ir_module
137+
class mod:
138+
@T.prim_func
139+
def main(n: T.int32):
140+
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
141+
T.attr(T.target("llvm"), "target", 0)
142+
T.evaluate(n)
143+
144+
return mod
145+
146+
def expected(self):
147+
@I.ir_module
148+
class mod:
149+
@T.prim_func
150+
def main(n: T.int32):
151+
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
152+
err = mod.main_kernel(n)
153+
assert err == 0, "Error executing compute kernel"
154+
155+
@T.prim_func
156+
def main_kernel(n: T.int32) -> T.int32:
157+
T.func_attr(
158+
{
159+
"target": T.target("llvm"),
160+
"tir.noalias": T.bool(True),
161+
"tir.is_global_func": True,
162+
}
163+
)
164+
T.evaluate(n)
165+
T.ret(0)
166+
167+
return mod
168+
169+
132170
class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
133171
"""Like TestSplitHostDevice, but no host specified in the host's target
134172

0 commit comments

Comments
 (0)