Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transform/lower_tile_op.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ class BufferGemmCollector : public StmtExprVisitor {

private:
void VisitStmt_(const EvaluateNode *op) {
auto call = Downcast<Call>(op->value);
const CallNode *call_node = op->value.as<CallNode>();
// Value of EvaluateNode may not be a call
if (!call_node) {
return;
}
auto call = Downcast<Call>(call_node);
if (call->op.same_as(Gemm::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
Expand Down
210 changes: 210 additions & 0 deletions src/transform/split_host_device.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

namespace tir = tvm::tir;

/*!
 * \brief Statement mutator that outlines device regions into separate
 *        functions and leaves a call to the new function in their place.
 *
 * Every statement passed through SplitDeviceFunc becomes the body of a new
 * PrimFunc. That PrimFunc is added to the module pointed to by `device_mod`
 * under a GlobalVar obtained from `var_supply`.
 */
class HostDeviceSplitter : public tir::StmtMutator {
public:
  /*!
   * \param device_mod Module that receives the generated device functions.
   * \param var_supply Callback returning the GlobalVar for each new function.
   */
  explicit HostDeviceSplitter(IRModule *device_mod,
                              std::function<GlobalVar()> var_supply)
      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {}

  /*!
   * \brief Outline the body of an attribute statement keyed by kTarget.
   *
   * Any other attribute statement is forwarded to the base mutator. For a
   * target attribute the body is handed to SplitDeviceFunc as-is, together
   * with the annotated target stripped of its host.
   */
  tir::Stmt VisitStmt_(const tir::AttrStmtNode *op) final {
    if (op->attr_key != tvm::attr::kTarget) {
      return tir::StmtMutator::VisitStmt_(op);
    }
    found_device_region_ = true;
    auto region_target = op->node.as<tvm::Target>().value().WithoutHost();
    return SplitDeviceFunc(op->body, region_target);
  }

  /*!
   * \brief Outline `body` directly, without going through a target
   *        attribute statement.
   */
  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
    return SplitDeviceFunc(std::move(body), std::move(device_target));
  }

  /*! \brief Whether VisitStmt_ has encountered a target attribute so far. */
  bool found_device_region() const { return found_device_region_; }

private:
  /*!
   * \brief Move `body` into a new device function and build the call to it.
   * \param body The statement to outline.
   * \param device_target Target attached to the generated function.
   * \return The statement that replaces `body` at the original location.
   */
  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
    // Variables and buffers used but not defined inside the body become the
    // parameters / buffer declarations of the outlined function.
    tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{},
                                   /*visit_thread_extent=*/true);
    use_def(body);

    // Order parameters by kind first (handles before everything else), then
    // by name.
    std::vector<tir::Var> ordered_vars(use_def.undefined_.begin(),
                                       use_def.undefined_.end());
    std::sort(ordered_vars.begin(), ordered_vars.end(),
              [](const tir::Var &lhs, const tir::Var &rhs) {
                const bool lhs_non_handle = !lhs->dtype.is_handle();
                const bool rhs_non_handle = !rhs->dtype.is_handle();
                if (lhs_non_handle != rhs_non_handle) {
                  // Exactly one side is a handle; lhs sorts first only when
                  // it is that handle.
                  return rhs_non_handle;
                }
                return lhs->name_hint < rhs->name_hint;
              });
    Array<tir::Var> params(ordered_vars.begin(), ordered_vars.end());
    Array<tir::Buffer> buffers_to_declare = use_def.undefined_buffers_;

    // CodeGenCPU is used for some device-side targets, such as "ext_dev",
    // and expects to be able to return an int32_t status code.
    const auto device_kind = device_target->GetTargetDeviceType();
    const bool can_propagate_errors = device_kind == kDLCPU ||
                                      device_kind == kDLExtDev ||
                                      device_kind == kDLHexagon;

    IntImm success(DataType::Int(32), 0);
    Type kernel_ret_type = VoidType();
    if (can_propagate_errors) {
      // Such kernels return the success code once their body has run.
      kernel_ret_type = PrimType(DataType::Int(32));
      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success)));
    }

    for (const tir::Buffer &buffer : buffers_to_declare) {
      body = tir::DeclBuffer(buffer, std::move(body));
    }

    tir::PrimFunc device_func(params, body, kernel_ret_type);
    device_func =
        WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                           {tir::attr::kNoAlias, true},
                                           {tir::attr::kIsGlobalFunc, true}});

    GlobalVar kernel_symbol_global = var_supply_();
    (*device_mod_)->Add(kernel_symbol_global, device_func);

    // The call forwards every parameter variable unchanged.
    Array<PrimExpr> args =
        params.Map([](const tir::Var &var) -> PrimExpr { return var; });

    if (!can_propagate_errors) {
      return tir::Evaluate(
          tir::Call(DataType::Void(), kernel_symbol_global, args));
    }

    // Bind the returned status code and assert that it equals `success`.
    tir::Var kernel_error_code("kernel_error_code", success->dtype);
    tir::Call kernel_call(success->dtype, kernel_symbol_global, args);
    tir::AssertStmt assert_success(
        kernel_error_code == success,
        tir::StringImm("Error executing compute kernel"), tir::Evaluate(0));
    return tir::LetStmt(kernel_error_code, kernel_call, assert_success);
  }

  // target ir module
  IRModule *device_mod_;
  // Generate new GlobalVar for the kernel
  std::function<GlobalVar()> var_supply_;
  // Set once a target attribute statement has been visited
  bool found_device_region_{false};
};

/*!
 * \brief Outline the device regions of one PrimFunc into `device_mod`.
 *
 * The function body is first run through HostDeviceSplitter. If the body is
 * unchanged and no target attribute was encountered, a split is forced only
 * when all of the following hold: the function carries a kTarget attribute
 * whose host-less target is defined, the function is marked kIsEntryFunc,
 * and its body is a no-op.
 *
 * \param func The function to process.
 * \param device_mod Module that receives the generated device functions.
 * \param var_supply Callback returning the GlobalVar for each new function.
 * \return The function with its body updated, or the input function as it
 *         was passed in when nothing was rewritten.
 */
tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod,
                              std::function<GlobalVar()> var_supply) {
  HostDeviceSplitter splitter(device_mod, std::move(var_supply));

  tir::Stmt rewritten = splitter(func->body);
  if (!rewritten.same_as(func->body)) {
    func.CopyOnWrite()->body = rewritten;
    return func;
  }

  // The mutator met a device region but left the body untouched: done.
  if (splitter.found_device_region()) {
    return func;
  }

  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
  if (!target) {
    return func;
  }

  auto device_target = target.value().WithoutHost();
  const bool is_empty_entry_func =
      device_target.defined() &&
      func->HasNonzeroAttr(tir::attr::kIsEntryFunc) &&
      tir::is_no_op(func->body);
  if (!is_empty_entry_func) {
    return func;
  }

  tir::Stmt forced = splitter.ForceSplit(func->body, device_target);
  if (!forced.same_as(func->body)) {
    func.CopyOnWrite()->body = forced;
  }
  return func;
}

namespace transform {

/*!
 * \brief Build the module pass named "tl.SplitHostDevice".
 *
 * For every PrimFunc in the module the pass calls tvm::tl::SplitHostDevice.
 * Generated kernels are named after the function's global symbol (or, if it
 * has none, the GlobalVar's name hint) with a "_kernel" suffix. Changed host
 * functions and the new device functions are merged back into the module,
 * which is finally run through ConvertSSA.
 */
tvm::transform::Pass SplitHostDevice() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
    tvm::GlobalVarSupply global_var_supply(mod);

    // Collected separately so that `mod->functions` is not modified while it
    // is being iterated.
    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));

    for (const auto &[gvar, base_func] : mod->functions) {
      auto maybe_prim_func = base_func.as<tir::PrimFunc>();
      if (!maybe_prim_func) {
        continue;
      }
      tir::PrimFunc func = maybe_prim_func.value();

      auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
      auto kernel_name = global_symbol.value_or(gvar->name_hint) + "_kernel";
      auto fresh_kernel_var = [&global_var_supply,
                               &kernel_name]() -> GlobalVar {
        return global_var_supply->FreshGlobal(kernel_name, false);
      };

      func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod,
                                        fresh_kernel_var);
      // Only record functions that were actually rewritten.
      if (!func.same_as(base_func)) {
        updates->Add(gvar, func);
      }
    }

    mod->Update(updates);
    mod->Update(device_mod);
    return tir::transform::ConvertSSA()(mod);
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice",
                                          {});
}

// Register the pass constructor as a global definition named
// "tl.transform.SplitHostDevice".
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});

} // namespace transform
} // namespace tl
} // namespace tvm
71 changes: 71 additions & 0 deletions testing/python/issue/test_tilelang_issue_830.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# ruff: noqa

import torch
import tilelang
import tilelang.testing
import tilelang.language as T


# Factory returning a prim_func whose kernel launch body contains only `pass`.
@tilelang.jit
def _empty_kernel():

    @T.prim_func
    def empty_kernel():
        # One block, 32 threads; the body is intentionally empty.
        with T.Kernel(1, threads=32) as thread_idx:
            pass

    return empty_kernel


def test_empty_kernel_lowering():
    """Build the kernel via `_empty_kernel()` and invoke it with no arguments."""
    empty = _empty_kernel()
    empty()


# Factory returning a prim_func whose only statement reads x[pid] into a local
# that is never used afterwards.
@tilelang.jit
def _empty_with_dead_code_kernel():
    # Symbolic extent shared by the tensor shape and the kernel grid.
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
        with T.Kernel(num_tokens, threads=32) as pid:
            # Dead load: `y` is not read or stored anywhere.
            y = x[pid]

    return buggy_kernel


@tilelang.testing.requires_cuda
def test_empty_with_dead_code_kernel():
    """Build the dead-code kernel and call it on a 128-element float32 CUDA tensor."""
    dead_code_kernel = _empty_with_dead_code_kernel()
    dead_code_kernel(torch.randn((128,), dtype=torch.float32, device="cuda"))


# Factory returning one of two kernels that differ only in how the T.Kernel
# binding is written: `as (pid,)` when use_tuple_binding is true, `as pid`
# otherwise.
@tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):

    @T.prim_func
    def kernel_with_tuple_kernel_binding():
        # Binding unpacked from a one-element tuple.
        with T.Kernel(1, threads=32) as (pid,):
            print(pid)
            pass

    @T.prim_func
    def kernel_with_scalar_kernel_binding():
        # Binding taken as a single name.
        with T.Kernel(1, threads=32) as pid:
            print(pid)
            pass

    return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding


def test_empty_kernel_with_binding_variants():
    """Build and invoke both binding variants: first the default, then the tuple form."""
    # An empty kwargs dict reproduces the argument-free first call.
    for factory_kwargs in ({}, {"use_tuple_binding": True}):
        variant = _empty_kernel_with_binding_variants(**factory_kwargs)
        variant()


# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
2 changes: 1 addition & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if allow_global_thread_synchronization():
mod = tilelang.transform.ThreadSync("global")(mod)
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tilelang.transform.SplitHostDevice()(mod)
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
Expand Down
55 changes: 50 additions & 5 deletions tilelang/language/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@
from tilelang import _ffi_api
import threading

# Ensure single-dimension kernel bindings can be unpacked like iterables,
# so a bare Var can be the target of `with T.Kernel(...) as (pid,):`.
# See https://github.com/tile-ai/tilelang/issues/830.
# Each attribute is installed only when Var does not already provide it.
if not hasattr(Var, "__iter__"):

    def _var_iter(self):
        # Iterating a Var yields the Var itself exactly once.
        yield self

    Var.__iter__ = _var_iter  # type: ignore[attr-defined]

if not hasattr(Var, "__len__"):
    # Report length 1, matching the single item yielded by _var_iter.
    Var.__len__ = lambda self: 1  # type: ignore[attr-defined]


class FrameStack:
"""
Expand Down Expand Up @@ -68,6 +80,17 @@ def _get_current_stack() -> FrameStack:
return _local.kernel_launch_frame_stack


def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]:
    """
    Collapse a one-element binding list to its sole Var.

    This lets callers write either `with T.Kernel(...) as pid:` or
    `with T.Kernel(...) as (pid,)` for a single-dimension launch. Lists of any
    other length are returned unchanged.
    """
    return bindings[0] if len(bindings) == 1 else bindings


@register_object("tl.KernelLaunchFrame")
class KernelLaunchFrame(TIRFrame):
"""
Expand All @@ -83,9 +106,6 @@ def __enter__(self) -> Union[Var, List[Var]]:
"""
super().__enter__()
_get_current_stack().push(self)
# If we have exactly 5 frames, return the single iter_var.var.
if len(self.frames) == 5:
return self.frames[0].iter_var.var

last_block_frame = self.frames[-1]
assert isinstance(last_block_frame,
Expand All @@ -95,11 +115,11 @@ def __enter__(self) -> Union[Var, List[Var]]:

if maybe_cpu:
# CPU kernel frame: return the first var of each for frame (a bare Var when there is only one).
return [frame.vars[0] for frame in self.frames[0:-1]]
return _normalize_bindings([frame.vars[0] for frame in self.frames[0:-1]])
else:
# Otherwise, return the iter_var.var of every frame except the last 4 (a bare Var when only one remains).
# The last 4 frames are threadIdx.x, threadIdx.y, threadIdx.z and the block frame with attributes.
return [frame.iter_var.var for frame in self.frames[0:-4]]
return _normalize_bindings([frame.iter_var.var for frame in self.frames[0:-4]])

def __exit__(self, ptype, value, trace):
"""
Expand Down Expand Up @@ -234,6 +254,31 @@ def Kernel(
-------
res : Tuple[frame.LaunchThreadFrame]
The result LaunchThreadFrame.

Examples
--------
Create a 1-D CUDA kernel launch and unpack the single block index:

.. code-block:: python

with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
# bx is the blockIdx.x binding (also iterable as (bx,))
...

Launch a 2-D grid while requesting two thread dimensions:

.. code-block:: python

with T.Kernel(grid_x, grid_y, threads=(64, 2)) as (bx, by):
tx, ty = T.get_thread_bindings()
...

Emit a CPU kernel where thread bindings are skipped:

.. code-block:: python

with T.Kernel(loop_extent, is_cpu=True) as (i,):
...
"""
attrs: dict = {}

Expand Down
Loading
Loading