
Commit 577bc6d

Add TileLang SplitHostDevice pass and tighten issue 830 test names
1 parent: ff8297f

4 files changed (+225 / -27 lines)


src/transform/split_host_device.cc

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file split_host_device.cc
 * \brief Split device function from host.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

namespace tir = tvm::tir;

class HostDeviceSplitter : public tir::StmtMutator {
 public:
  explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply)
      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {}

  tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) final {
    if (op->attr_key == tvm::attr::kTarget) {
      found_device_region_ = true;
      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
      return SplitDeviceFunc(op->body, device_target);
    }
    return tir::StmtMutator::VisitStmt_(op);
  }

  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
    return SplitDeviceFunc(std::move(body), std::move(device_target));
  }

  bool found_device_region() const { return found_device_region_; }

 private:
  bool found_device_region_{false};

  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
    auto [params, buffers_to_declare] =
        [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
      tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true);
      use_def(body);

      // Sort first by variable type, then by variable name
      std::vector<tir::Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
      std::sort(params.begin(), params.end(), [](const tir::Var& a, const tir::Var& b) {
        auto sort_key = [](const tir::Var& var) {
          return std::tuple{
              !var->dtype.is_handle(),
              var->name_hint,
          };
        };
        return sort_key(a) < sort_key(b);
      });
      return {params, use_def.undefined_buffers_};
    }();

    // CodeGenCPU is used for some device-side targets, such as
    // "ext_dev", and expects to be able to return an int32_t status
    // code.

    bool can_propagate_errors = [&]() {
      auto kind = device_target->GetTargetDeviceType();
      return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
    }();
    IntImm success(DataType::Int(32), 0);
    Type kernel_ret_type;
    if (can_propagate_errors) {
      kernel_ret_type = PrimType(DataType::Int(32));
      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success)));
    } else {
      kernel_ret_type = VoidType();
    }

    for (tir::Buffer buf : buffers_to_declare) {
      body = tir::DeclBuffer(buf, std::move(body));
    }
    tir::PrimFunc device_func(params, body, kernel_ret_type);
    device_func = WithAttrs(
        std::move(device_func),
        {{tvm::attr::kTarget, device_target},
         {tir::attr::kNoAlias, true},
         {tir::attr::kIsGlobalFunc, true}});

    GlobalVar kernel_symbol_global = var_supply_();
    (*device_mod_)->Add(kernel_symbol_global, device_func);
    Array<PrimExpr> args = params.Map([](const tir::Var& var) -> PrimExpr { return var; });

    if (can_propagate_errors) {
      tir::Var kernel_error_code("kernel_error_code", success->dtype);
      tir::Call kernel_call(success->dtype, kernel_symbol_global, args);
      tir::AssertStmt assert_success(
          kernel_error_code == success, tir::StringImm("Error executing compute kernel"),
          tir::Evaluate(0));
      tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success);

      return let_check;

    } else {
      return tir::Evaluate(tir::Call(DataType::Void(), kernel_symbol_global, args));
    }
  }

  // target ir module
  IRModule* device_mod_;
  // Generate new GlobalVar for the kernel
  std::function<GlobalVar()> var_supply_;
};

tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule* device_mod,
                              std::function<GlobalVar()> var_supply) {
  HostDeviceSplitter splitter(device_mod, std::move(var_supply));

  if (auto body = splitter(func->body); !body.same_as(func->body)) {
    func.CopyOnWrite()->body = body;
  } else if (!splitter.found_device_region()) {
    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
      auto device_target = target.value().WithoutHost();
      if (device_target.defined() && func->HasNonzeroAttr(tir::attr::kIsEntryFunc) &&
          tir::is_no_op(func->body)) {
        if (auto forced = splitter.ForceSplit(func->body, device_target);
            !forced.same_as(func->body)) {
          func.CopyOnWrite()->body = forced;
        }
      }
    }
  }

  return func;
}

namespace transform {

tvm::transform::Pass SplitHostDevice() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
    tvm::GlobalVarSupply global_var_supply(mod);

    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));

    for (const auto& [gvar, base_func] : mod->functions) {
      if (auto opt = base_func.as<tir::PrimFunc>()) {
        tir::PrimFunc func = opt.value();

        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        auto name_prefix = global_symbol.value_or(gvar->name_hint);
        auto kernel_name = name_prefix + "_kernel";
        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
          return global_var_supply->FreshGlobal(kernel_name, false);
        };

        func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod, var_supply);
        if (!func.same_as(base_func)) {
          updates->Add(gvar, func);
        }
      }
    }

    mod->Update(updates);
    mod->Update(device_mod);
    return tir::transform::ConvertSSA()(mod);
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});

} // namespace transform
} // namespace tl
} // namespace tvm
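
For orientation, a minimal sketch of driving the newly registered pass from Python (not part of this commit; the helper name split_for_device and the `mod` variable are assumptions, while the two pass wrappers come from this diff):

import tilelang.transform

def split_for_device(mod):
    # Device regions must carry a target attribute before the split.
    mod = tilelang.transform.AnnotateDeviceRegions()(mod)
    # Device bodies become new PrimFuncs named "<global_symbol>_kernel";
    # host code is rewritten to call them, with an error-status check on
    # CPU-like targets (kDLCPU, kDLExtDev, kDLHexagon).
    mod = tilelang.transform.SplitHostDevice()(mod)
    return mod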

testing/python/issue/test_tilelang_issue_830.py

File mode changed from 100755 to 100644
Lines changed: 11 additions & 26 deletions
@@ -1,41 +1,26 @@
-# yapf: disable
 # ruff: noqa
 
-# This testfile can't be formatted by yapf & ruff
-
 import tilelang
 import tilelang.testing
 import tilelang.language as T
 
 
-@tilelang.jit()
-def get_buggy_kernel():
+@tilelang.jit
+def _empty_kernel():
+
     @T.prim_func
-    def buggy():
-        with T.Kernel(1, threads=32) as pid:
+    def empty_kernel():
+        with T.Kernel(1, threads=32) as thread_idx:
             A_shared = T.alloc_shared((1,), "float32")
 
-    return buggy
-
-
-@tilelang.jit()
-def get_buggy_kernel1():
-    num_tokens = T.symbolic('num_tokens')
-
-    @T.prim_func
-    def buggy(x: T.Tensor[(num_tokens, ), 'float'],):
-        with T.Kernel(num_tokens, threads=32) as pid:
-            y = x[pid]
+    return empty_kernel
 
-    return buggy
 
-def test_dummy_kernel_gen():
-    """Test dummy kernel generation"""
-    # Currently still can't pass the test
-    # kernel = get_buggy_kernel()
-    # kernel()
-    pass
+def test_empty_kernel_lowering():
+    kernel = _empty_kernel()
+    kernel()
 
 
 if __name__ == "__main__":
-    tilelang.testing.main()
+    test_empty_kernel_lowering()
+    # tilelang.testing.main()

tilelang/engine/phase.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     if allow_global_thread_synchronization():
         mod = tilelang.transform.ThreadSync("global")(mod)
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
-    mod = tir.transform.SplitHostDevice()(mod)
+    mod = tilelang.transform.SplitHostDevice()(mod)
     # MergeSharedMemoryAllocations must be applied after SplitHostDevice
     # because the merged allocation site is at the beginning of each device function
     enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)

tilelang/transform/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -282,6 +282,17 @@ def AnnotateDeviceRegions():
     return _ffi_api.AnnotateDeviceRegions()  # type: ignore
 
 
+def SplitHostDevice():
+    """Split host/device functions even for empty kernels.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.SplitHostDevice()  # type: ignore
+
+
 def VectorizeLoop(enable_vectorize: bool = True):
     """VectorizeLoop
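
A short usage sketch for this wrapper (illustrative only, not part of this commit; `mod` is an assumed tir IRModule whose device regions were annotated earlier): the returned tvm.transform.Pass is applied like any other module pass, and the device kernels it splits out appear under names ending in "_kernel", the suffix chosen in split_host_device.cc above.

import tilelang.transform

def run_split(mod):
    mod = tilelang.transform.SplitHostDevice()(mod)
    # Illustrative check: split-out device kernels carry the "_kernel" suffix.
    kernel_names = [gv.name_hint for gv in mod.functions]
    assert any(name.endswith("_kernel") for name in kernel_names)
    return mod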
