Skip to content

Commit c921781

Browse files
authored
[TIR] Output DeclBuffer in SplitHostDevice (#15493)
* [TIR] Output DeclBuffer in SplitHostDevice If the generated device function uses a buffer, generate a DeclBuffer for the buffer at the top of the device function. This is a subset of the changes made in #14778, broken out for ease of testing and review. * Updated thread sync test to account for DeclBuffer * Updated LowerWarp unit tests to find Allocate in PrimFunc
1 parent 8f60213 commit c921781

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

src/tir/transforms/split_host_device.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class HostDeviceSplitter : public StmtMutator {
5656

5757
private:
5858
Stmt SplitDeviceFunc(Stmt body, Target device_target) {
59-
Array<Var> params = [&]() {
59+
auto [params, buffers_to_declare] = [&]() -> std::tuple<Array<Var>, Array<Buffer>> {
6060
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
6161
use_def(body);
6262

@@ -71,7 +71,7 @@ class HostDeviceSplitter : public StmtMutator {
7171
};
7272
return sort_key(a) < sort_key(b);
7373
});
74-
return params;
74+
return {params, use_def.undefined_buffers_};
7575
}();
7676

7777
// CodeGenCPU is used for some device-side targets, such as
@@ -91,12 +91,15 @@ class HostDeviceSplitter : public StmtMutator {
9191
kernel_ret_type = VoidType();
9292
}
9393

94-
GlobalVar kernel_symbol_global = var_supply_();
94+
for (Buffer buf : buffers_to_declare) {
95+
body = DeclBuffer(buf, std::move(body));
96+
}
9597
PrimFunc device_func(params, body, kernel_ret_type);
9698
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
9799
{tir::attr::kNoAlias, Bool(true)},
98100
{tir::attr::kIsGlobalFunc, Bool(true)}});
99101

102+
GlobalVar kernel_symbol_global = var_supply_();
100103
(*device_mod_)->Add(kernel_symbol_global, device_func);
101104
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
102105

tests/python/unittest/test_tir_transform_lower_warp_memory.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
import tvm
2020
import tvm.testing
21-
from tvm import te
21+
from tvm import te, tir
2222
from tvm.contrib.nvcc import have_fp16
2323

2424

@@ -55,9 +55,13 @@ def test_lower_warp_memory_local_scope():
5555

5656
mod = _run_passes(mod)
5757
fdevice = mod["f_kernel"]
58-
allocate = fdevice.body.body
58+
59+
allocate = fdevice
60+
while not isinstance(allocate, tir.Allocate):
61+
allocate = allocate.body
62+
5963
assert allocate.buffer_var.type_annotation.storage_scope == "local"
60-
assert fdevice.body.body.extents[0].value == 2
64+
assert allocate.extents[0].value == 2
6165

6266

6367
@tvm.testing.requires_cuda

tests/python/unittest/test_tir_transform_thread_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_thread_storage_sync():
5757
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
5858
mod = run_passes(func)
5959
f = mod["test_kernel"]
60-
body_list = tvm.tir.stmt_list(f.body.body.body)
60+
body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
6161
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
6262

6363

0 commit comments

Comments
 (0)