@SiriusNEO
Contributor

@SiriusNEO SiriusNEO commented Oct 9, 2025

This PR fixes part of #830 and adds a test for the issue. Because the problem is not yet completely fixed, the test is commented out for now.

Summary by CodeRabbit

  • New Features

    • Exposed a host/device splitting transformation as a pipeline pass for extracting device kernels and invoking them from host code.
    • Added kernel binding unpacking so single-dimension bindings unpack to a single variable for simpler kernel APIs.
  • Bug Fixes

    • Prevented runtime errors by defensively handling evaluated expressions that aren't calls while preserving GEMM behavior.
  • Tests

    • Added regression tests covering empty-kernel lowering, dead-code scenarios, and binding-variant kernels.

@github-actions

github-actions bot commented Oct 9, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 9, 2025

Walkthrough

Add a defensive cast for EvaluateNode values to avoid invalid downcasts; add tests reproducing the issue and kernel-binding variants; implement a new SplitHostDevice pass that extracts device kernels into a device IRModule and rewrites hosts to call them; expose and wire the pass into TileLang Python APIs and support single-binding unpacking in kernel bindings.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Tile lowering safety fix: src/transform/lower_tile_op.cc | Replace a direct Downcast<Call>(op->value) with a defensive op->value.as<CallNode>() check; return early if not a CallNode and only construct Call when present; preserve GEMM/GemmSP handling and lowering control flow. |
| Issue reproduction tests: testing/python/issue/test_tilelang_issue_830.py | Add a Python test module with multiple tilelang JIT kernels and tests: _empty_kernel, _empty_with_dead_code_kernel, _empty_kernel_with_binding_variants and corresponding test functions that exercise lowering, dead-code scenario (CUDA), and single-vs-tuple binding variants. |
| Host/device split pass implementation: src/transform/split_host_device.cc | Add HostDeviceSplitter (tir::StmtMutator) and SplitHostDevice(tir::PrimFunc, IRModule*, std::function<GlobalVar()>); implement logic to analyze host regions, emit device PrimFuncs with appropriate params/attributes, register them in a device IRModule, and rewrite the host to call generated kernels with optional error propagation. Register the pass via FFI. |
| TileLang pass wiring & API: tilelang/engine/phase.py, tilelang/transform/__init__.py | Replace usage of TVM's tir.transform.SplitHostDevice with tilelang.transform.SplitHostDevice in the OptimizeForTarget flow and add a Python wrapper SplitHostDevice() in tilelang.transform.__init__ that returns the new pass. |
| Kernel binding unpacking: tilelang/language/kernel.py | Allow single-dimension kernel bindings to be unpacked by making Var iterable/length-1 and adding _normalize_bindings(bindings); normalize bindings in KernelLaunchFrame.__enter__; update docstrings/examples to document unpacking semantics. |
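
The single-binding unpacking summarized in the last entry above can be pictured with a small standalone sketch. This is not the code in tilelang/language/kernel.py: the `Var` class and `normalize_bindings` helper below are hypothetical stand-ins, written only under the assumption that a lone binding should behave both as a plain variable and as a one-element tuple.

```python
from typing import List, Sequence, Union


class Var:
    """Hypothetical stand-in for a TIR variable that can be unpacked as a 1-tuple."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __iter__(self):
        # Yield the variable itself so that `(bx,) = var` works.
        yield self

    def __len__(self) -> int:
        return 1


def normalize_bindings(bindings: Sequence[Var]) -> Union[Var, List[Var]]:
    """Return the lone Var for a single binding, otherwise the list of Vars."""
    if len(bindings) == 1:
        return bindings[0]
    return list(bindings)


single = normalize_bindings([Var("bx")])
(unpacked,) = single  # tuple-style unpacking of a single binding
pair = normalize_bindings([Var("bx"), Var("by")])
bx, by = pair
```

With this shape, both `with ... as bx:` and `with ... as (bx,):` style usage would see the same variable, which is one plausible reading of the "single-vs-tuple binding variants" the tests exercise.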

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Host as Host PrimFunc
  participant Splitter as HostDeviceSplitter
  participant DeviceMod as Device IRModule
  participant Kernel as Generated PrimFunc (device)
  participant HostCall as Rewritten Host (kernel call)

  Host->>Splitter: Visit host body (detect device regions)
  Splitter->>Splitter: Analyze uses/defs → collect params, buffers
  Splitter->>DeviceMod: Emit device PrimFunc (params, attrs, target)
  DeviceMod-->>Splitter: Register Kernel GlobalVar
  Splitter->>HostCall: Replace device region with kernel call (build args)
  alt error propagation enabled
    HostCall->>Host: Insert runtime check (Let/Assert) around call
  else
    HostCall->>Host: Emit direct kernel call
  end
  HostCall-->>Host: Return rewritten host function

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

I nibble code with careful paws,
I guard each cast from risky laws.
I split the host and tuck kernels neat,
Tests hop in with tiny feet. 🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 22.22% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title succinctly describes a bugfix for the compilation of a dummy kernel, which reflects the core objective of resolving the issue and adding the corresponding test infrastructure. |
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 10adb79 and cce7b6a.

📒 Files selected for processing (2)
  • src/transform/lower_tile_op.cc (1 hunks)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (5)
src/tl_templates/cuda/reduce.h (1)
  • T (75-147)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/kernel.py (1)
  • threads (195-199)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/__init__.py (1)
  • symbolic (80-91)
🪛 GitHub Actions: CI Test on AMD
testing/python/issue/test_tilelang_issue_830.py

[error] 4-4: F401: torch imported but unused


[error] 12-12: F841: Local variable 'pid' is assigned to but never used


[error] 13-13: F841: Local variable 'A_shared' is assigned to but never used


[error] 25-25: F841: Local variable 'y' is assigned to but never used


[error] 32-32: F841: Local variable 'kernel' is assigned to but never used

🪛 Ruff (0.13.3)
testing/python/issue/test_tilelang_issue_830.py

1-1: The file is executable but no shebang is present

(EXE002)


11-11: Local variable pid is assigned to but never used

Remove assignment to unused variable pid

(F841)


12-12: Local variable A_shared is assigned to but never used

Remove assignment to unused variable A_shared

(F841)


24-24: Local variable y is assigned to but never used

Remove assignment to unused variable y

(F841)


30-30: Local variable kernel is assigned to but never used

Remove assignment to unused variable kernel

(F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: pr_reminder
  • GitHub Check: format-check
🔇 Additional comments (4)
src/transform/lower_tile_op.cc (1)

88-93: LGTM: Safe cast prevents crashes on non-Call EvaluateNode values.

The defensive pattern correctly handles cases where EvaluateNode::value is not a Call, preventing the crash that likely caused issue #830. The safe cast to CallNode*, null check, early return, and subsequent Downcast<Call> form a robust guard for GEMM collection.
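
The guard described above can be illustrated outside TVM. The following is a minimal Python sketch, not the actual C++ in lower_tile_op.cc: the classes `IntImm`, `Call`, and `Evaluate`, the helper `collect_gemm_calls`, and the op names `"gemm"` and `"gemm_sp"` are hypothetical stand-ins, with `isinstance` playing the role of `op->value.as<CallNode>()`.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class IntImm:
    """Stand-in for a constant expression such as the 0 in T.evaluate(0)."""
    value: int


@dataclass
class Call:
    """Stand-in for a call expression; `op` is the callee name."""
    op: str
    args: List[int] = field(default_factory=list)


@dataclass
class Evaluate:
    """Stand-in for an Evaluate statement wrapping an arbitrary expression."""
    value: Union[IntImm, Call]


def collect_gemm_calls(stmts: List[Evaluate]) -> List[Call]:
    """Collect GEMM-like calls, skipping Evaluate nodes whose value is not a Call."""
    found = []
    for stmt in stmts:
        value = stmt.value
        # Checked cast: skip non-call values instead of assuming a Call.
        if not isinstance(value, Call):
            continue
        if value.op in ("gemm", "gemm_sp"):
            found.append(value)
    return found


stmts = [Evaluate(IntImm(0)), Evaluate(Call("gemm", [1, 2])), Evaluate(Call("copy"))]
gemm_calls = collect_gemm_calls(stmts)
```

The `Evaluate(IntImm(0))` entry corresponds to the `T.evaluate(0)` body of an empty kernel, which is the kind of value an unchecked downcast would reject.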

testing/python/issue/test_tilelang_issue_830.py (3)

11-12: Unused variables pid and A_shared are intentional.

These variables are intentionally unused as this is test scaffolding for kernel generation (issue #830), not execution. The static analysis warnings can be safely ignored here.


24-24: Unused variable y is intentional.

This variable is intentionally unused as part of the test scaffolding for kernel generation. The static analysis warning can be safely ignored.


28-32: Test scaffolding acknowledged; kernel execution commented out as expected.

Per the PR description, the underlying issue for #830 is not completely resolved, so the commented-out kernel execution is appropriate for now. The unused kernel variable flag can be safely ignored until the test is fully enabled.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cce7b6a and 8f1f9d2.

📒 Files selected for processing (1)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (6)
src/tl_templates/cuda/reduce.h (1)
  • T (75-147)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/kernel.py (1)
  • threads (195-199)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/__init__.py (1)
  • symbolic (80-91)
tilelang/language/proxy.py (1)
  • Tensor (254-255)
🪛 GitHub Actions: CI Test on Metal
testing/python/issue/test_tilelang_issue_830.py

[error] 11-11: flake8: F841 Local variable 'pid' is assigned to but never used


[error] 12-12: flake8: F841 Local variable 'A_shared' is assigned to but never used


[error] 24-24: flake8: F841 Local variable 'y' is assigned to but never used


[error] 31-31: flake8: F841 Local variable 'kernel' is assigned to but never used

🪛 Ruff (0.13.3)
testing/python/issue/test_tilelang_issue_830.py

1-1: The file is executable but no shebang is present

(EXE002)


10-10: Local variable pid is assigned to but never used

Remove assignment to unused variable pid

(F841)


11-11: Local variable A_shared is assigned to but never used

Remove assignment to unused variable A_shared

(F841)


23-23: Local variable y is assigned to but never used

Remove assignment to unused variable y

(F841)


29-29: Local variable kernel is assigned to but never used

Remove assignment to unused variable kernel

(F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
testing/python/issue/test_tilelang_issue_830.py (2)

15-16: Prefix unused variables with underscore to pass linting.

The variables pid and A_shared are assigned but never used, causing linting failures. Since these are part of the bug reproduction case, prefix them with underscore to mark them as intentionally unused.

Apply this diff:

-        with T.Kernel(1, threads=32) as pid:
-            A_shared = T.alloc_shared((1,), "float32")
+        with T.Kernel(1, threads=32) as _pid:
+            _A_shared = T.alloc_shared((1,), "float32")

28-28: Prefix unused variable with underscore to pass linting.

The variable y is assigned but never used, causing linting failures. Since this is part of the bug reproduction case, prefix it with underscore to mark it as intentionally unused.

Apply this diff:

-            y = x[pid]
+            _y = x[pid]
🧹 Nitpick comments (1)
testing/python/issue/test_tilelang_issue_830.py (1)

32-37: Consider using pytest.skip or pytest.xfail instead of commented code.

The test is currently inactive via commented code. A cleaner approach would be to use pytest's skip or xfail decorators to formally mark the test as blocked, which provides better visibility in test reports and maintains executable code.

For example:

import pytest

@pytest.mark.skip(reason="Blocked by issue #830: underlying problem not fully resolved")
def test_dummy_kernel_gen():
    """Test dummy kernel generation"""
    kernel = get_buggy_kernel()
    kernel()

Or with xfail if you want it to run but expect failure:

@pytest.mark.xfail(reason="Issue #830 not fully resolved")
def test_dummy_kernel_gen():
    """Test dummy kernel generation"""
    kernel = get_buggy_kernel()
    kernel()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfc1668 and c349bd8.

📒 Files selected for processing (1)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (4)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/kernel.py (1)
  • threads (195-199)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/__init__.py (1)
  • symbolic (80-91)
🔇 Additional comments (1)
testing/python/issue/test_tilelang_issue_830.py (1)

21-30: Verify whether get_buggy_kernel1() should be tested or removed.

This function is defined but never called in the test suite. If it's intended to test a different aspect of issue #830, consider adding it to test_dummy_kernel_gen() (when the underlying issue is resolved) or documenting its purpose. If it's not needed, remove it to reduce maintenance burden.

@LeiWang1999
Member

I made an enhancement. With it, the empty kernel is now split as follows.

Before SplitHostDevice:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def buggy():
        T.func_attr({"target": T.target({"arch": "sm_90", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_entry_func": True, "tma_descriptor_args": {}})
        T.evaluate(0)

After SplitHostDevice:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def buggy():
        T.func_attr({"target": T.target({"arch": "sm_90", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_entry_func": True, "tma_descriptor_args": {}})
        Module.buggy_kernel()

    @T.prim_func(private=True)
    def buggy_kernel():
        T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_global_func": True, "tir.noalias": True})
        T.evaluate(0)

And the following test

import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def _empty_kernel():
    @T.prim_func
    def empty_kernel():
        with T.Kernel(1, threads=32) as thread_idx:
            A_shared = T.alloc_shared((1,), "float32")

    return empty_kernel


def test_empty_kernel_lowering():
    kernel = _empty_kernel()
    kernel()


if __name__ == "__main__":
    test_empty_kernel_lowering()

passes with this change.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
testing/python/issue/test_tilelang_issue_830.py (3)

1-1: ruff: noqa suppresses all lint warnings for the entire file.

The # ruff: noqa comment disables all Ruff checks for this file. If past review comments about unused variables have been addressed, consider removing this blanket suppression and using more targeted suppressions if needed.

Consider removing the blanket suppression if all lint issues are resolved:

-# ruff: noqa

8-16: Consider adding a docstring to document the test case.

The _empty_kernel function creates a minimal empty kernel for reproduction. Adding a docstring would help explain the purpose and context of this test case, especially its relation to issue #830.

 @tilelang.jit
 def _empty_kernel():
+    """Create a minimal empty kernel for testing host-device splitting.
+    
+    This reproduces the issue where empty kernels failed to lower correctly
+    before the SplitHostDevice enhancement. See issue #830.
+    """
 
     @T.prim_func
     def empty_kernel():

13-14: Prefix intentionally unused variables with underscore or add inline suppression

thread_idx and A_shared are not referenced in this minimal repro. Rename them to _thread_idx and _A_shared, or append # noqa: F841, to avoid potential lint warnings.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c349bd8 and 577bc6d.

📒 Files selected for processing (5)
  • src/transform/lower_tile_op.cc (1 hunks)
  • src/transform/split_host_device.cc (1 hunks)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/transform/lower_tile_op.cc
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/transform/__init__.py (1)
src/transform/split_host_device.cc (4)
  • SplitHostDevice (138-158)
  • SplitHostDevice (138-139)
  • SplitHostDevice (162-193)
  • SplitHostDevice (162-162)
testing/python/issue/test_tilelang_issue_830.py (2)
tilelang/language/kernel.py (1)
  • threads (195-199)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
src/transform/split_host_device.cc (2)
tilelang/language/tir/op.py (1)
  • ret (1758-1771)
tilelang/transform/__init__.py (1)
  • SplitHostDevice (285-293)
tilelang/engine/phase.py (2)
src/transform/split_host_device.cc (4)
  • SplitHostDevice (138-158)
  • SplitHostDevice (138-139)
  • SplitHostDevice (162-193)
  • SplitHostDevice (162-162)
tilelang/transform/__init__.py (1)
  • SplitHostDevice (285-293)
🪛 GitHub Actions: CI Test on AMD
src/transform/split_host_device.cc

[error] 1-1: Reformatted files. Please review and stage the changes.

🪛 GitHub Actions: CI Test on Metal
src/transform/split_host_device.cc

[error] 1-1: Code formatting changes detected by tile-lang formatter. Please review and stage changes or run the formatter. The step './format.sh' exited with code 1.

🔇 Additional comments (13)
tilelang/engine/phase.py (1)

189-189: LGTM! Correct delegation to TileLang-specific host-device splitting.

The change correctly delegates to the new TileLang-specific SplitHostDevice implementation that handles empty kernels, aligning with the PR objective to fix empty kernel lowering.

tilelang/transform/__init__.py (1)

285-293: LGTM! Consistent with existing pass accessor pattern.

The SplitHostDevice() function follows the established pattern for pass accessors in this module, with clear documentation and proper FFI delegation.

testing/python/issue/test_tilelang_issue_830.py (1)

19-21: LGTM! Test correctly exercises the empty kernel lowering path.

The test appropriately invokes the empty kernel to verify the host-device splitting works correctly for empty kernels, which is the core issue being addressed in this PR.

src/transform/split_host_device.cc (10)

43-46: LGTM! Well-designed class structure for host-device splitting.

The HostDeviceSplitter class appropriately extends tir::StmtMutator and stores the device module and symbol generator for creating device functions.


48-55: LGTM! Correct detection and handling of device regions.

The visitor correctly identifies device regions by checking for the kTarget attribute and delegates to SplitDeviceFunc for processing, properly tracking whether a device region was found.
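
To make the region-detection step concrete, here is a small Python sketch of the same idea on a toy statement tree. It is only an analogy under stated assumptions: `Evaluate`, `AttrStmt`, `SeqStmt`, and `ToySplitter` are hypothetical classes, the string `"target"` stands in for `tvm::attr::kTarget`, and the generated kernel names are invented for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class Evaluate:
    value: str


@dataclass
class AttrStmt:
    attr_key: str
    node: str  # for "target" attributes, the target name
    body: "Stmt"


@dataclass
class SeqStmt:
    seq: List["Stmt"]


Stmt = Union[Evaluate, AttrStmt, SeqStmt]


class ToySplitter:
    """Replace each target-annotated region with a call to an extracted kernel."""

    def __init__(self) -> None:
        self.found_device_region = False
        self.kernels: List[Tuple[str, str, Stmt]] = []  # (name, target, body)

    def visit(self, stmt: Stmt) -> Stmt:
        if isinstance(stmt, AttrStmt):
            if stmt.attr_key == "target":
                # Device region: record the kernel and leave a call behind.
                self.found_device_region = True
                name = f"kernel_{len(self.kernels)}"
                self.kernels.append((name, stmt.node, stmt.body))
                return Evaluate(f"call {name}")
            # Other attributes: keep the node and recurse into its body.
            return AttrStmt(stmt.attr_key, stmt.node, self.visit(stmt.body))
        if isinstance(stmt, SeqStmt):
            return SeqStmt([self.visit(s) for s in stmt.seq])
        return stmt


host = SeqStmt([
    Evaluate("host_setup"),
    AttrStmt("target", "cuda", Evaluate("device_work")),
])
splitter = ToySplitter()
rewritten = splitter.visit(host)
```

Tracking `found_device_region` separately from the rewritten body matters because a function with no target-annotated region at all (such as an empty kernel) leaves the flag false, which is what the later fallback keys on.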


57-61: LGTM! Simple utility methods for forced splitting and status checking.

The ForceSplit method provides a way to force device function creation even without explicit device regions, which is essential for handling empty kernels. The found_device_region getter is appropriately straightforward.


66-84: LGTM! Correct analysis and sorting of device function parameters.

The use of VarUseDefAnalyzer to identify undefined variables (function parameters) is appropriate, and the sorting logic ensures a consistent parameter order (handles before non-handles, then alphabetically by name).
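
The ordering rule can be sketched in a few lines of Python. The `Param` record below is a hypothetical stand-in for `tir::Var`; the sort key mirrors the `std::tuple{!var->dtype.is_handle(), var->name_hint}` key in the C++.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Param:
    name_hint: str
    is_handle: bool


def sort_params(params: List[Param]) -> List[Param]:
    # False sorts before True, so handles (not is_handle == False) come first;
    # ties are broken by name.
    return sorted(params, key=lambda p: (not p.is_handle, p.name_hint))


params = [Param("n", False), Param("B", True), Param("m", False), Param("A", True)]
ordered = [p.name_hint for p in sort_params(params)]
```

A deterministic order like this keeps the generated kernel signature stable regardless of the order in which undefined variables are discovered.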


90-101: LGTM! Appropriate error propagation for supported targets.

The logic correctly identifies targets that can propagate error codes (CPU, ExtDev, Hexagon) and adds appropriate return handling. For other targets (e.g., CUDA), void return is used, which is correct for device kernels.


103-115: LGTM! Correct device function creation and registration.

The device function is properly constructed with buffer declarations, appropriate attributes (target, no_alias, is_global_func), and registered in the device module with a fresh global symbol.


117-129: LGTM! Proper host-side call generation with error checking.

The host-side replacement correctly generates either:

  • An error-checking sequence (for targets that support error propagation)
  • A simple evaluate call (for device-only targets)

This ensures proper error handling where supported while maintaining simplicity for pure device targets.
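
The two host-side shapes can be sketched as follows. This is a simplified Python model, not TVM code: device kinds are plain strings rather than `kDLCPU`, `kDLExtDev`, and `kDLHexagon`, and the IR is represented by nested tuples chosen only for illustration.

```python
from typing import Sequence, Tuple


def can_propagate_errors(device_kind: str) -> bool:
    """Only these device kinds are treated as able to return a status code."""
    return device_kind in {"cpu", "ext_dev", "hexagon"}


def build_host_call(kernel_name: str, args: Sequence[str], device_kind: str) -> Tuple:
    """Build the statement that replaces a device region on the host side."""
    call = ("call", kernel_name, tuple(args))
    if can_propagate_errors(device_kind):
        # Bind the return code, then assert that it equals the success value 0.
        check = ("assert", "kernel_error_code == 0", "Error executing compute kernel")
        return ("let", "kernel_error_code", call, check)
    # Device-only targets: the kernel returns void, so just evaluate the call.
    return ("evaluate", call)


cuda_stmt = build_host_call("buggy_kernel", [], "cuda")
cpu_stmt = build_host_call("buggy_kernel", ["A"], "cpu")
```

For the CUDA case this matches the `Module.buggy_kernel()` call shown in the "After SplitHostDevice" listing earlier in the thread, where no error check surrounds the call.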


138-158: LGTM! Critical fix for empty kernel handling.

This function correctly implements the fix for issue #830 by forcing a split for empty entry functions with device targets (lines 144-154). The logic appropriately:

  • Attempts normal splitting first
  • Falls back to ForceSplit only when no device region exists AND the function is an entry function with a no-op body
  • Only modifies the function if changes are necessary

This directly addresses the PR objective of fixing empty kernel lowering.
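
The decision order in that fallback can be written out as a tiny Python function. The function name and the returned labels are hypothetical; the boolean inputs correspond to the conditions checked in the C++ (`same_as`, `found_device_region()`, the presence of a device target, `kIsEntryFunc`, and `is_no_op`).

```python
def plan_split(
    body_changed: bool,
    found_device_region: bool,
    has_device_target: bool,
    is_entry_func: bool,
    body_is_no_op: bool,
) -> str:
    """Return which branch the splitting logic would take."""
    if body_changed:
        # Normal path: the mutator rewrote at least one device region.
        return "use_rewritten_body"
    if (
        not found_device_region
        and has_device_target
        and is_entry_func
        and body_is_no_op
    ):
        # Empty entry function with a device target: force a kernel stub.
        return "force_split"
    return "unchanged"


empty_kernel_plan = plan_split(False, False, True, True, True)
host_only_plan = plan_split(False, False, False, True, True)
normal_plan = plan_split(True, True, True, True, False)
```

Requiring all four conditions keeps the forced split narrow: host-only functions and non-entry helpers are left unchanged.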


162-193: LGTM! Correct module pass implementation.

The pass correctly:

  • Processes all PrimFuncs in the module
  • Generates unique kernel names with "_kernel" suffix
  • Accumulates changes in separate modules (updates and device_mod)
  • Applies updates atomically
  • Converts to SSA form before returning

The implementation follows TVM's pass conventions and properly integrates the device functions into the module.


195-198: LGTM! Proper FFI registration for Python access.

The static initialization block correctly registers the pass under "tl.transform.SplitHostDevice" using TVM's reflection system, enabling Python access through the API defined in tilelang/transform/__init__.py.

Comment on lines 1 to 202
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

namespace tir = tvm::tir;

class HostDeviceSplitter : public tir::StmtMutator {
 public:
  explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply)
      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {}

  tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) final {
    if (op->attr_key == tvm::attr::kTarget) {
      found_device_region_ = true;
      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
      return SplitDeviceFunc(op->body, device_target);
    }
    return tir::StmtMutator::VisitStmt_(op);
  }

  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
    return SplitDeviceFunc(std::move(body), std::move(device_target));
  }

  bool found_device_region() const { return found_device_region_; }

 private:
  bool found_device_region_{false};

  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
    auto [params, buffers_to_declare] =
        [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
      tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true);
      use_def(body);

      // Sort first by variable type, then by variable name
      std::vector<tir::Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
      std::sort(params.begin(), params.end(), [](const tir::Var& a, const tir::Var& b) {
        auto sort_key = [](const tir::Var& var) {
          return std::tuple{
              !var->dtype.is_handle(),
              var->name_hint,
          };
        };
        return sort_key(a) < sort_key(b);
      });
      return {params, use_def.undefined_buffers_};
    }();

    // CodeGenCPU is used for some device-side targets, such as
    // "ext_dev", and expects to be able to return a int32_t status
    // code.

    bool can_propagate_errors = [&]() {
      auto kind = device_target->GetTargetDeviceType();
      return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
    }();
    IntImm success(DataType::Int(32), 0);
    Type kernel_ret_type;
    if (can_propagate_errors) {
      kernel_ret_type = PrimType(DataType::Int(32));
      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success)));
    } else {
      kernel_ret_type = VoidType();
    }

    for (tir::Buffer buf : buffers_to_declare) {
      body = tir::DeclBuffer(buf, std::move(body));
    }
    tir::PrimFunc device_func(params, body, kernel_ret_type);
    device_func = WithAttrs(
        std::move(device_func),
        {{tvm::attr::kTarget, device_target},
         {tir::attr::kNoAlias, true},
         {tir::attr::kIsGlobalFunc, true}});

    GlobalVar kernel_symbol_global = var_supply_();
    (*device_mod_)->Add(kernel_symbol_global, device_func);
    Array<PrimExpr> args = params.Map([](const tir::Var& var) -> PrimExpr { return var; });

    if (can_propagate_errors) {
      tir::Var kernel_error_code("kernel_error_code", success->dtype);
      tir::Call kernel_call(success->dtype, kernel_symbol_global, args);
      tir::AssertStmt assert_success(
          kernel_error_code == success, tir::StringImm("Error executing compute kernel"),
          tir::Evaluate(0));
      tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success);

      return let_check;

    } else {
      return tir::Evaluate(tir::Call(DataType::Void(), kernel_symbol_global, args));
    }
  }

  // target ir module
  IRModule* device_mod_;
  // Generate new GlobalVar for the kernel
  std::function<GlobalVar()> var_supply_;
};

tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule* device_mod,
                              std::function<GlobalVar()> var_supply) {
  HostDeviceSplitter splitter(device_mod, std::move(var_supply));

  if (auto body = splitter(func->body); !body.same_as(func->body)) {
    func.CopyOnWrite()->body = body;
  } else if (!splitter.found_device_region()) {
    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
      auto device_target = target.value().WithoutHost();
      if (device_target.defined() && func->HasNonzeroAttr(tir::attr::kIsEntryFunc) &&
          tir::is_no_op(func->body)) {
        if (auto forced = splitter.ForceSplit(func->body, device_target);
            !forced.same_as(func->body)) {
          func.CopyOnWrite()->body = forced;
        }
      }
    }
  }

  return func;
}

namespace transform {

tvm::transform::Pass SplitHostDevice() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
    tvm::GlobalVarSupply global_var_supply(mod);

    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));

    for (const auto& [gvar, base_func] : mod->functions) {
      if (auto opt = base_func.as<tir::PrimFunc>()) {
        tir::PrimFunc func = opt.value();

        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        auto name_prefix = global_symbol.value_or(gvar->name_hint);
        auto kernel_name = name_prefix + "_kernel";
        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
          return global_var_supply->FreshGlobal(kernel_name, false);
        };

        func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod, var_supply);
        if (!func.same_as(base_func)) {
          updates->Add(gvar, func);
        }
      }
    }

    mod->Update(updates);
    mod->Update(device_mod);
    return tir::transform::ConvertSSA()(mod);
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});

}  // namespace transform
} // namespace tl
} // namespace tvm
Contributor

⚠️ Potential issue | 🟠 Major

Address code formatting to pass CI checks.

The CI pipeline is failing due to code formatting issues detected by the tile-lang formatter. Please run the formatter to ensure the code adheres to the project's style guidelines.

Run the formatter to fix the issues:

#!/bin/bash
# Run the tile-lang formatter
./format.sh
🧰 Tools
🪛 GitHub Actions: CI Test on AMD

[error] 1-1: Reformatted files. Please review and stage the changes.

🪛 GitHub Actions: CI Test on Metal

[error] 1-1: Code formatting changes detected by tile-lang formatter. Please review and stage changes or run the formatter. The step './format.sh' exited with code 1.

🤖 Prompt for AI Agents
In src/transform/split_host_device.cc lines 1-202 the file fails the project's
code-style checks; run the repository formatter and commit the resulting changes
to fix spacing/indentation and line-wrapping so CI passes: execute the provided
formatter script (./format.sh), review the modified file to ensure no logic
changes were introduced, stage and commit the formatted file, and push the
commit so the CI pipeline can re-run.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
src/transform/split_host_device.cc (3)

52-52: Add defensive check for Target cast.

The .value() call will throw if the optional is empty. While the node should always be a Target when attr_key == kTarget, defensive programming suggests checking first.

Apply this diff:

-      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
+      auto target_opt = op->node.as<tvm::Target>();
+      ICHECK(target_opt.defined()) << "Expected Target node for kTarget attribute";
+      auto device_target = target_opt.value().WithoutHost();

147-161: Clarify the empty kernel handling logic.

The fallback logic for empty kernels with ForceSplit is subtle. Consider adding a comment explaining why empty entry functions with device targets need forced splitting.

Apply this diff:

   if (auto body = splitter(func->body); !body.same_as(func->body)) {
     func.CopyOnWrite()->body = body;
   } else if (!splitter.found_device_region()) {
+    // Empty entry functions with device targets still need host/device
+    // splitting to properly generate kernel stubs and device modules.
     if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {

169-196: FreshGlobal guarantees unique kernel names
Consider adding a test that runs SplitHostDevice on multiple functions sharing the same name_prefix to verify no name collisions occur.
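As a rough illustration of the property such a test would check, here is a minimal Python sketch of a name supply that never hands out the same name twice. `FreshNames` and its `fresh` method are hypothetical stand-ins written for this example; they are not the `FreshGlobal` API used by the pass, and the suffix scheme is an assumption.

```python
class FreshNames:
    """Hypothetical name supply: hands out names that were never used before."""

    def __init__(self):
        self._used = set()

    def fresh(self, prefix):
        # Try the bare prefix first, then prefix_1, prefix_2, ... until unused.
        candidate, suffix = prefix, 0
        while candidate in self._used:
            suffix += 1
            candidate = f"{prefix}_{suffix}"
        self._used.add(candidate)
        return candidate


supply = FreshNames()
# Three requests with the same prefix must yield three distinct names.
names = [supply.fresh("main_kernel") for _ in range(3)]
```

A real regression test would presumably run the pass on a module containing several functions with the same prefix and assert that the resulting kernel names are pairwise distinct, in the same way `len(set(names)) == len(names)` holds here.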

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c349bd8 and 7109ddf.

📒 Files selected for processing (5)
  • src/transform/lower_tile_op.cc (1 hunks)
  • src/transform/split_host_device.cc (1 hunks)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/transform/lower_tile_op.cc
  • testing/python/issue/test_tilelang_issue_830.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/transform/__init__.py (1)
src/transform/split_host_device.cc (4)
  • SplitHostDevice (143-164)
  • SplitHostDevice (143-144)
  • SplitHostDevice (168-201)
  • SplitHostDevice (168-168)
tilelang/engine/phase.py (2)
src/transform/split_host_device.cc (4)
  • SplitHostDevice (143-164)
  • SplitHostDevice (143-144)
  • SplitHostDevice (168-201)
  • SplitHostDevice (168-168)
tilelang/transform/__init__.py (1)
  • SplitHostDevice (285-293)
src/transform/split_host_device.cc (2)
tilelang/language/tir/op.py (1)
  • ret (1758-1771)
tilelang/transform/__init__.py (1)
  • SplitHostDevice (285-293)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
🔇 Additional comments (5)
src/transform/split_host_device.cc (3)

94-105: LGTM! Error propagation logic is well-designed.

The conditional error propagation based on device type is correct. CPU, ExtDev, and Hexagon backends can return status codes, while GPU backends use void returns. The success value and return statement injection are properly implemented.


68-88: LGTM! Variable analysis and parameter sorting are correct.

The VarUseDefAnalyzer correctly identifies undefined variables (function parameters) and buffers. The sorting logic (non-handles first, then by name) ensures a stable, predictable parameter order for generated kernels.


203-206: LGTM! FFI registration follows TVM conventions.

The registration under tl.transform.SplitHostDevice properly exposes the pass to Python via the FFI layer.

tilelang/transform/__init__.py (1)

285-293: LGTM! FFI wrapper follows consistent patterns.

The SplitHostDevice wrapper correctly delegates to _ffi_api.SplitHostDevice() and includes an appropriate docstring highlighting the "even for empty kernels" capability.

tilelang/engine/phase.py (1)

189-189: LGTM! Correctly switches to TileLang-specific SplitHostDevice.

This change properly integrates the new TileLang SplitHostDevice pass (with empty kernel handling) into the compilation pipeline, replacing TVM's default implementation.

@SiriusNEO SiriusNEO changed the title from "[Bugfix] Fix visit EvaluateNode in BufferGemmCollector" to "[Bugfix] Fix dummy kernel compilation" Oct 10, 2025
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (5)
testing/python/issue/test_tilelang_issue_830.py (5)

1-1: Consider more targeted linting suppressions.

The blanket # ruff: noqa suppresses all linting for the file. Since this test intentionally has unused variables to reproduce bugs, consider using targeted suppressions like # ruff: noqa: F841 to allow other checks to run.
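A minimal sketch of what a targeted suppression could look like. The function below is a hypothetical stand-in for the test's dead-code reproduction, not the actual test file; only the `# ruff: noqa: F841` directive is taken from the suggestion above.

```python
# ruff: noqa: F841
# File-level directive: suppresses only F841 (unused local variable),
# so every other Ruff rule still applies to this file.


def dead_code_example():
    # Intentionally assigned and never read, mirroring a dead-code reproduction.
    unused = 1 + 1
    return None
```

If only one or two lines are affected, a per-line `# noqa: F841` comment on the offending assignment is a narrower alternative.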


20-22: Consider adding assertions for more informative test failures.

The test calls the kernel but doesn't assert anything about the behavior. Consider adding an assertion to make test failures more informative, such as checking that the kernel executes without raising exceptions.

Apply this diff:

 def test_empty_kernel_lowering():
     kernel = _empty_kernel()
-    kernel()
+    assert kernel() is None  # or check expected return value

37-41: Consider adding an assertion.

The test could verify that the kernel executes successfully by asserting it doesn't raise an exception or by checking output values.

Apply this diff:

 @tilelang.testing.requires_cuda
 def test_empty_with_dead_code_kernel():
     kernel = _empty_with_dead_code_kernel()
     x = torch.randn((128,), dtype=torch.float32, device="cuda")
-    kernel(x)
+    kernel(x)  # Should not raise
+    assert x is not None  # Verify input remains valid

44-59: Remove debug print statements.

Both kernel variants contain print(pid) statements that appear to be debug code. These should either be removed or replaced with meaningful assertions if they're intended to verify kernel execution.

Apply this diff:

     @T.prim_func
     def kernel_with_tuple_kernel_binding():
         with T.Kernel(1, threads=32) as (pid,):
-            print(pid)
             pass
 
     @T.prim_func
     def kernel_with_scalar_kernel_binding():
         with T.Kernel(1, threads=32) as pid:
-            print(pid)
             pass

62-67: Consider adding assertions to verify binding variants work correctly.

The test exercises both binding styles but doesn't assert their correctness. Consider adding assertions to make test failures more informative.

Apply this diff:

 def test_empty_kernel_with_binding_variants():
     kernel = _empty_kernel_with_binding_variants()
-    kernel()
+    assert kernel() is None  # Scalar binding should work
 
     tuple_kernel = _empty_kernel_with_binding_variants(use_tuple_binding=True)
-    tuple_kernel()
+    assert tuple_kernel() is None  # Tuple binding should work
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7109ddf and 40adc97.

📒 Files selected for processing (4)
  • src/transform/lower_tile_op.cc (1 hunks)
  • testing/python/issue/test_tilelang_issue_830.py (1 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/language/kernel.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/transform/lower_tile_op.cc
  • tilelang/engine/phase.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/kernel.py (1)
tilelang/language/ast/ir.py (2)
  • iter_var (1616-1639)
  • var (1517-1533)
testing/python/issue/test_tilelang_issue_830.py (4)
src/tl_templates/cuda/reduce.h (1)
  • T (75-147)
tilelang/jit/__init__.py (1)
  • jit (244-317)
tilelang/language/kernel.py (2)
  • Kernel (229-303)
  • threads (215-219)
tilelang/language/__init__.py (1)
  • symbolic (84-95)
🪛 Ruff (0.13.3)
tilelang/language/kernel.py

22-22: Unused lambda argument: self

(ARG005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-metal
🔇 Additional comments (9)
testing/python/issue/test_tilelang_issue_830.py (4)

3-6: LGTM!

All imports are appropriately used in the test cases.


9-17: LGTM!

The empty kernel factory correctly tests basic kernel compilation and launch frame creation.


25-34: LGTM!

The dead code kernel correctly reproduces the bug scenario from issue #830 where an assigned but unused variable triggers compilation issues.


70-71: LGTM!

The main guard correctly uses the TileLang testing framework's entry point.

tilelang/language/kernel.py (5)

83-92: LGTM!

The helper function cleanly normalizes bindings to return a bare Var for single-dimension launches while preserving list semantics for multi-dimensional launches. This enables both with T.Kernel(n) as bx: and with T.Kernel(n, m) as (bx, by): unpacking patterns.
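The normalization described above can be sketched as follows. This is an illustration under assumptions: the real `_normalize_bindings` signature is not shown in this thread, the function name below is hypothetical, and plain strings stand in for `Var` objects.

```python
def normalize_bindings(bindings):
    """Return the sole element of a one-item list, otherwise the list itself."""
    if len(bindings) == 1:
        return bindings[0]
    return bindings


# Single-dimension launch: the binding is a bare value.
bx = normalize_bindings(["bx"])

# Multi-dimension launch: the list is preserved, so tuple unpacking still works.
by0, by1 = normalize_bindings(["bx", "by"])
```

Returning the bare element for the one-dimensional case is what allows `as bx:` without a trailing comma, while the multi-dimensional case keeps its list semantics.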


118-118: LGTM!

Applying _normalize_bindings to the CPU path ensures consistent unpacking behavior across both CPU and GPU kernel launches.


122-122: LGTM!

Applying _normalize_bindings to the non-CPU path completes the symmetric unpacking support for single and multi-dimensional kernel launches.


257-281: LGTM!

The updated examples clearly demonstrate the new unpacking behavior for single and multi-dimensional kernel launches. The examples cover 1-D CUDA, 2-D CUDA with multiple thread dimensions, and CPU kernels.

Note: The CPU example shows as (i,): with a trailing comma. Users might find it more natural to write as i: for single-dimension CPU kernels, which is now supported.


12-22: Monkey-patching is safe. No code paths check whether the kernel bindings are a list, so all with T.Kernel usages remain compatible. The unused-self warning is a false positive.

@LeiWang1999
Member

This also fixes the pid length-matching bug from issue #830 with the following code:

if not hasattr(Var, "__iter__"):

    def _var_iter(self):
        yield self

    Var.__iter__ = _var_iter  # type: ignore[attr-defined]

if not hasattr(Var, "__len__"):
    Var.__len__ = lambda self: 1  # type: ignore[attr-defined]

Now, both of these code styles are valid and will return the correct pid value.

@tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):

    @T.prim_func
    def kernel_with_tuple_kernel_binding():
        with T.Kernel(1, threads=32) as (pid,):
            print(pid)
            pass

    @T.prim_func
    def kernel_with_scalar_kernel_binding():
        with T.Kernel(1, threads=32) as pid:
            print(pid)
            pass

    return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding
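To see why the patch makes both binding styles work, here is a self-contained sketch that applies the same guarded patching pattern to a stand-in class. The `Var` class below is hypothetical and only stores a name; it is not the TIR `Var`, and no TileLang import is needed to run it.

```python
class Var:
    """Hypothetical stand-in for the TIR Var class; only a name is stored."""

    def __init__(self, name):
        self.name = name


# Same guarded patching pattern as in the comment above.
if not hasattr(Var, "__iter__"):

    def _var_iter(self):
        yield self

    Var.__iter__ = _var_iter

if not hasattr(Var, "__len__"):
    Var.__len__ = lambda self: 1

pid = Var("pid")

# Tuple-style binding: unpacking iterates over the Var, which yields itself once.
(unpacked,) = pid

# Scalar-style binding: the Var is used directly.
scalar = pid
```

Because the iterator yields the object itself exactly once, `as (pid,)` and `as pid` end up bound to the same variable.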

@LeiWang1999
Member

Local tests pass; we can merge this pull request now.

@LeiWang1999 LeiWang1999 merged commit 7913fb1 into tile-ai:main Oct 10, 2025
6 of 7 checks passed