[Bugfix] Fix dummy kernel compilation #962
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Add a defensive cast for EvaluateNode values to avoid invalid downcasts; add tests reproducing the issue and kernel-binding variants; implement a new SplitHostDevice pass that extracts device kernels into a device IRModule and rewrites hosts to call them; expose and wire the pass into TileLang Python APIs; and support single-binding unpacking in kernel bindings.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant Host as Host PrimFunc
  participant Splitter as HostDeviceSplitter
  participant DeviceMod as Device IRModule
  participant Kernel as Generated PrimFunc (device)
  participant HostCall as Rewritten Host (kernel call)
  Host->>Splitter: Visit host body (detect device regions)
  Splitter->>Splitter: Analyze uses/defs → collect params, buffers
  Splitter->>DeviceMod: Emit device PrimFunc (params, attrs, target)
  DeviceMod-->>Splitter: Register Kernel GlobalVar
  Splitter->>HostCall: Replace device region with kernel call (build args)
  alt error propagation enabled
    HostCall->>Host: Insert runtime check (Let/Assert) around call
  else
    HostCall->>Host: Emit direct kernel call
  end
  HostCall-->>Host: Return rewritten host function
```
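As a rough illustration of the flow above, here is a minimal sketch of invoking the pass from Python. It assumes only that `tilelang.transform.SplitHostDevice()` returns a standard TVM module pass (as the changes describe) and that `mod` is an already-lowered `tvm.IRModule`; it is not the project's pipeline code.

```python
# Hedged sketch: apply the TileLang SplitHostDevice pass to a lowered module.
# Assumes `mod` is a tvm.IRModule whose entry PrimFunc carries a device target.
import tvm
import tilelang.transform


def split_host_device(mod: tvm.IRModule) -> tvm.IRModule:
    # The pass extracts device regions into private "<name>_kernel" PrimFuncs
    # and rewrites each host body to call the generated kernel.
    return tilelang.transform.SplitHostDevice()(mod)
```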
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/transform/lower_tile_op.cc (1 hunks)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (5)
- src/tl_templates/cuda/reduce.h (1): T (75-147)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/kernel.py (1): threads (195-199)
- tilelang/language/allocate.py (1): alloc_shared (21-36)
- tilelang/language/__init__.py (1): symbolic (80-91)
🪛 GitHub Actions: CI Test on AMD
testing/python/issue/test_tilelang_issue_830.py
[error] 4-4: F401: torch imported but unused
[error] 12-12: F841: Local variable 'pid' is assigned to but never used
[error] 13-13: F841: Local variable 'A_shared' is assigned to but never used
[error] 25-25: F841: Local variable 'y' is assigned to but never used
[error] 32-32: F841: Local variable 'kernel' is assigned to but never used
🪛 Ruff (0.13.3)
testing/python/issue/test_tilelang_issue_830.py
1-1: The file is executable but no shebang is present
(EXE002)
11-11: Local variable pid is assigned to but never used
Remove assignment to unused variable pid
(F841)
12-12: Local variable A_shared is assigned to but never used
Remove assignment to unused variable A_shared
(F841)
24-24: Local variable y is assigned to but never used
Remove assignment to unused variable y
(F841)
30-30: Local variable kernel is assigned to but never used
Remove assignment to unused variable kernel
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: pr_reminder
- GitHub Check: format-check
🔇 Additional comments (4)
src/transform/lower_tile_op.cc (1)
88-93: LGTM: Safe cast prevents crashes on non-Call EvaluateNode values.

The defensive pattern correctly handles cases where `EvaluateNode::value` is not a `Call`, preventing the crash that likely caused issue #830. The safe cast to `CallNode*`, null check, early return, and subsequent `Downcast<Call>` form a robust guard for GEMM collection.

testing/python/issue/test_tilelang_issue_830.py (3)

11-12: Unused variables `pid` and `A_shared` are intentional.

These variables are intentionally unused as this is test scaffolding for kernel generation (issue #830), not execution. The static analysis warnings can be safely ignored here.

24-24: Unused variable `y` is intentional.

This variable is intentionally unused as part of the test scaffolding for kernel generation. The static analysis warning can be safely ignored.

28-32: Test scaffolding acknowledged; kernel execution commented out as expected.

Per the PR description, the underlying issue for #830 is not completely resolved, so the commented-out kernel execution is appropriate for now. The unused `kernel` variable flag can be safely ignored until the test is fully enabled.
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (6)
- src/tl_templates/cuda/reduce.h (1): T (75-147)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/kernel.py (1): threads (195-199)
- tilelang/language/allocate.py (1): alloc_shared (21-36)
- tilelang/language/__init__.py (1): symbolic (80-91)
- tilelang/language/proxy.py (1): Tensor (254-255)
🪛 GitHub Actions: CI Test on Metal
testing/python/issue/test_tilelang_issue_830.py
[error] 11-11: flake8: F841 Local variable 'pid' is assigned to but never used
[error] 12-12: flake8: F841 Local variable 'A_shared' is assigned to but never used
[error] 24-24: flake8: F841 Local variable 'y' is assigned to but never used
[error] 31-31: flake8: F841 Local variable 'kernel' is assigned to but never used
🪛 Ruff (0.13.3)
testing/python/issue/test_tilelang_issue_830.py
1-1: The file is executable but no shebang is present
(EXE002)
10-10: Local variable pid is assigned to but never used
Remove assignment to unused variable pid
(F841)
11-11: Local variable A_shared is assigned to but never used
Remove assignment to unused variable A_shared
(F841)
23-23: Local variable y is assigned to but never used
Remove assignment to unused variable y
(F841)
29-29: Local variable kernel is assigned to but never used
Remove assignment to unused variable kernel
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
Actionable comments posted: 0
♻️ Duplicate comments (2)
testing/python/issue/test_tilelang_issue_830.py (2)
15-16: Prefix unused variables with underscore to pass linting.

The variables `pid` and `A_shared` are assigned but never used, causing linting failures. Since these are part of the bug reproduction case, prefix them with underscore to mark them as intentionally unused. Apply this diff:

```diff
- with T.Kernel(1, threads=32) as pid:
-     A_shared = T.alloc_shared((1,), "float32")
+ with T.Kernel(1, threads=32) as _pid:
+     _A_shared = T.alloc_shared((1,), "float32")
```

28-28: Prefix unused variable with underscore to pass linting.

The variable `y` is assigned but never used, causing linting failures. Since this is part of the bug reproduction case, prefix it with underscore to mark it as intentionally unused. Apply this diff:

```diff
- y = x[pid]
+ _y = x[pid]
```
🧹 Nitpick comments (1)
testing/python/issue/test_tilelang_issue_830.py (1)
32-37: Consider using pytest.skip or pytest.xfail instead of commented code.

The test is currently inactive via commented code. A cleaner approach would be to use pytest's skip or xfail decorators to formally mark the test as blocked, which provides better visibility in test reports and maintains executable code.

For example:

```python
import pytest

@pytest.mark.skip(reason="Blocked by issue #830: underlying problem not fully resolved")
def test_dummy_kernel_gen():
    """Test dummy kernel generation"""
    kernel = get_buggy_kernel()
    kernel()
```

Or with xfail if you want it to run but expect failure:

```python
@pytest.mark.xfail(reason="Issue #830 not fully resolved")
def test_dummy_kernel_gen():
    """Test dummy kernel generation"""
    kernel = get_buggy_kernel()
    kernel()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_830.py (4)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/kernel.py (1): threads (195-199)
- tilelang/language/allocate.py (1): alloc_shared (21-36)
- tilelang/language/__init__.py (1): symbolic (80-91)
🔇 Additional comments (1)
testing/python/issue/test_tilelang_issue_830.py (1)
21-30: Verify whether `get_buggy_kernel1()` should be tested or removed.

This function is defined but never called in the test suite. If it's intended to test a different aspect of issue #830, consider adding it to `test_dummy_kernel_gen()` (when the underlying issue is resolved) or documenting its purpose. If it's not needed, remove it to reduce maintenance burden.
…an/fix_dummy_kernel
Made an enhancement; the pass now lowers this case as follows. Before SplitHostDevice:

```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def buggy():
        T.func_attr({"target": T.target({"arch": "sm_90", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_entry_func": True, "tma_descriptor_args": {}})
        T.evaluate(0)
```

After SplitHostDevice:

```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def buggy():
        T.func_attr({"target": T.target({"arch": "sm_90", "host": {"keys": ["cpu"], "kind": "c", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_entry_func": True, "tma_descriptor_args": {}})
        Module.buggy_kernel()

    @T.prim_func(private=True)
    def buggy_kernel():
        T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_global_func": True, "tir.noalias": True})
        T.evaluate(0)
```

And this test:

```python
import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def _empty_kernel():

    @T.prim_func
    def empty_kernel():
        with T.Kernel(1, threads=32) as thread_idx:
            A_shared = T.alloc_shared((1,), "float32")

    return empty_kernel


def test_empty_kernel_lowering():
    kernel = _empty_kernel()
    kernel()


if __name__ == "__main__":
    test_empty_kernel_lowering()
```

can pass.
Actionable comments posted: 1
🧹 Nitpick comments (3)
testing/python/issue/test_tilelang_issue_830.py (3)
1-1: `ruff: noqa` suppresses all lint warnings for the entire file.

The `# ruff: noqa` comment disables all Ruff checks for this file. If past review comments about unused variables have been addressed, consider removing this blanket suppression and using more targeted suppressions if needed. Consider removing the blanket suppression if all lint issues are resolved:

```diff
-# ruff: noqa
```

8-16: Consider adding a docstring to document the test case.

The `_empty_kernel` function creates a minimal empty kernel for reproduction. Adding a docstring would help explain the purpose and context of this test case, especially its relation to issue #830.

```diff
 @tilelang.jit
 def _empty_kernel():
+    """Create a minimal empty kernel for testing host-device splitting.
+
+    This reproduces the issue where empty kernels failed to lower correctly
+    before the SplitHostDevice enhancement. See issue #830.
+    """

     @T.prim_func
     def empty_kernel():
```

13-14: Prefix intentionally unused variables with underscore or add inline suppression.

`thread_idx` and `A_shared` are not referenced in this minimal repro; rename them to `_thread_idx` and `_A_shared`, or append `# noqa: F841`, to avoid potential lint warnings.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/transform/lower_tile_op.cc (1 hunks)
- src/transform/split_host_device.cc (1 hunks)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/transform/lower_tile_op.cc
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/transform/__init__.py (1)
- src/transform/split_host_device.cc (4): SplitHostDevice (138-158), SplitHostDevice (138-139), SplitHostDevice (162-193), SplitHostDevice (162-162)

testing/python/issue/test_tilelang_issue_830.py (2)
- tilelang/language/kernel.py (1): threads (195-199)
- tilelang/language/allocate.py (1): alloc_shared (21-36)

src/transform/split_host_device.cc (2)
- tilelang/language/tir/op.py (1): ret (1758-1771)
- tilelang/transform/__init__.py (1): SplitHostDevice (285-293)

tilelang/engine/phase.py (2)
- src/transform/split_host_device.cc (4): SplitHostDevice (138-158), SplitHostDevice (138-139), SplitHostDevice (162-193), SplitHostDevice (162-162)
- tilelang/transform/__init__.py (1): SplitHostDevice (285-293)
🪛 GitHub Actions: CI Test on AMD
src/transform/split_host_device.cc
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 GitHub Actions: CI Test on Metal
src/transform/split_host_device.cc
[error] 1-1: Code formatting changes detected by tile-lang formatter. Please review and stage changes or run the formatter. The step './format.sh' exited with code 1.
🔇 Additional comments (13)
tilelang/engine/phase.py (1)
189-189: LGTM! Correct delegation to TileLang-specific host-device splitting.

The change correctly delegates to the new TileLang-specific `SplitHostDevice` implementation that handles empty kernels, aligning with the PR objective to fix empty kernel lowering.

tilelang/transform/__init__.py (1)

285-293: LGTM! Consistent with existing pass accessor pattern.

The `SplitHostDevice()` function follows the established pattern for pass accessors in this module, with clear documentation and proper FFI delegation.

testing/python/issue/test_tilelang_issue_830.py (1)

19-21: LGTM! Test correctly exercises the empty kernel lowering path.

The test appropriately invokes the empty kernel to verify the host-device splitting works correctly for empty kernels, which is the core issue being addressed in this PR.
src/transform/split_host_device.cc (10)
43-46: LGTM! Well-designed class structure for host-device splitting.

The `HostDeviceSplitter` class appropriately extends `tir::StmtMutator` and stores the device module and symbol generator for creating device functions.

48-55: LGTM! Correct detection and handling of device regions.

The visitor correctly identifies device regions by checking for the `kTarget` attribute and delegates to `SplitDeviceFunc` for processing, properly tracking whether a device region was found.

57-61: LGTM! Simple utility methods for forced splitting and status checking.

The `ForceSplit` method provides a way to force device function creation even without explicit device regions, which is essential for handling empty kernels. The `found_device_region` getter is appropriately straightforward.

66-84: LGTM! Correct analysis and sorting of device function parameters.

The use of `VarUseDefAnalyzer` to identify undefined variables (function parameters) is appropriate, and the sorting logic ensures a consistent parameter order (handles before non-handles, then alphabetically by name).

90-101: LGTM! Appropriate error propagation for supported targets.

The logic correctly identifies targets that can propagate error codes (CPU, ExtDev, Hexagon) and adds appropriate return handling. For other targets (e.g., CUDA), void return is used, which is correct for device kernels.

103-115: LGTM! Correct device function creation and registration.

The device function is properly constructed with buffer declarations, appropriate attributes (target, no_alias, is_global_func), and registered in the device module with a fresh global symbol.

117-129: LGTM! Proper host-side call generation with error checking.

The host-side replacement correctly generates either:
- An error-checking sequence (for targets that support error propagation)
- A simple evaluate call (for device-only targets)
This ensures proper error handling where supported while maintaining simplicity for pure device targets.
138-158: LGTM! Critical fix for empty kernel handling.

This function correctly implements the fix for issue #830 by forcing a split for empty entry functions with device targets (lines 144-154). The logic appropriately:
- Attempts normal splitting first
- Falls back to `ForceSplit` only when no device region exists AND the function is an entry function with a no-op body
- Only modifies the function if changes are necessary
This directly addresses the PR objective of fixing empty kernel lowering.
162-193: LGTM! Correct module pass implementation.

The pass correctly:
- Processes all PrimFuncs in the module
- Generates unique kernel names with "_kernel" suffix
- Accumulates changes in separate modules (updates and device_mod)
- Applies updates atomically
- Converts to SSA form before returning
The implementation follows TVM's pass conventions and properly integrates the device functions into the module.
195-198: LGTM! Proper FFI registration for Python access.

The static initialization block correctly registers the pass under "tl.transform.SplitHostDevice" using TVM's reflection system, enabling Python access through the API defined in `tilelang/transform/__init__.py`.
src/transform/split_host_device.cc (Outdated)
```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file split_host_device.cc
 * \brief Split device function from host.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {

namespace tir = tvm::tir;

class HostDeviceSplitter : public tir::StmtMutator {
 public:
  explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply)
      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {}

  tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) final {
    if (op->attr_key == tvm::attr::kTarget) {
      found_device_region_ = true;
      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
      return SplitDeviceFunc(op->body, device_target);
    }
    return tir::StmtMutator::VisitStmt_(op);
  }

  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
    return SplitDeviceFunc(std::move(body), std::move(device_target));
  }

  bool found_device_region() const { return found_device_region_; }

 private:
  bool found_device_region_{false};

  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
    auto [params, buffers_to_declare] =
        [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
      tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/true);
      use_def(body);

      // Sort first by variable type, then by variable name
      std::vector<tir::Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
      std::sort(params.begin(), params.end(), [](const tir::Var& a, const tir::Var& b) {
        auto sort_key = [](const tir::Var& var) {
          return std::tuple{
              !var->dtype.is_handle(),
              var->name_hint,
          };
        };
        return sort_key(a) < sort_key(b);
      });
      return {params, use_def.undefined_buffers_};
    }();

    // CodeGenCPU is used for some device-side targets, such as
    // "ext_dev", and expects to be able to return a int32_t status
    // code.

    bool can_propagate_errors = [&]() {
      auto kind = device_target->GetTargetDeviceType();
      return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
    }();
    IntImm success(DataType::Int(32), 0);
    Type kernel_ret_type;
    if (can_propagate_errors) {
      kernel_ret_type = PrimType(DataType::Int(32));
      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success)));
    } else {
      kernel_ret_type = VoidType();
    }

    for (tir::Buffer buf : buffers_to_declare) {
      body = tir::DeclBuffer(buf, std::move(body));
    }
    tir::PrimFunc device_func(params, body, kernel_ret_type);
    device_func = WithAttrs(
        std::move(device_func),
        {{tvm::attr::kTarget, device_target},
         {tir::attr::kNoAlias, true},
         {tir::attr::kIsGlobalFunc, true}});

    GlobalVar kernel_symbol_global = var_supply_();
    (*device_mod_)->Add(kernel_symbol_global, device_func);
    Array<PrimExpr> args = params.Map([](const tir::Var& var) -> PrimExpr { return var; });

    if (can_propagate_errors) {
      tir::Var kernel_error_code("kernel_error_code", success->dtype);
      tir::Call kernel_call(success->dtype, kernel_symbol_global, args);
      tir::AssertStmt assert_success(
          kernel_error_code == success, tir::StringImm("Error executing compute kernel"),
          tir::Evaluate(0));
      tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success);

      return let_check;
    } else {
      return tir::Evaluate(tir::Call(DataType::Void(), kernel_symbol_global, args));
    }
  }

  // target ir module
  IRModule* device_mod_;
  // Generate new GlobalVar for the kernel
  std::function<GlobalVar()> var_supply_;
};

tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule* device_mod,
                              std::function<GlobalVar()> var_supply) {
  HostDeviceSplitter splitter(device_mod, std::move(var_supply));

  if (auto body = splitter(func->body); !body.same_as(func->body)) {
    func.CopyOnWrite()->body = body;
  } else if (!splitter.found_device_region()) {
    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
      auto device_target = target.value().WithoutHost();
      if (device_target.defined() && func->HasNonzeroAttr(tir::attr::kIsEntryFunc) &&
          tir::is_no_op(func->body)) {
        if (auto forced = splitter.ForceSplit(func->body, device_target);
            !forced.same_as(func->body)) {
          func.CopyOnWrite()->body = forced;
        }
      }
    }
  }

  return func;
}

namespace transform {

tvm::transform::Pass SplitHostDevice() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
    tvm::GlobalVarSupply global_var_supply(mod);

    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));

    for (const auto& [gvar, base_func] : mod->functions) {
      if (auto opt = base_func.as<tir::PrimFunc>()) {
        tir::PrimFunc func = opt.value();

        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        auto name_prefix = global_symbol.value_or(gvar->name_hint);
        auto kernel_name = name_prefix + "_kernel";
        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
          return global_var_supply->FreshGlobal(kernel_name, false);
        };

        func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod, var_supply);
        if (!func.same_as(base_func)) {
          updates->Add(gvar, func);
        }
      }
    }

    mod->Update(updates);
    mod->Update(device_mod);
    return tir::transform::ConvertSSA()(mod);
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});

}  // namespace transform
}  // namespace tl
}  // namespace tvm
```
Address code formatting to pass CI checks.
The CI pipeline is failing due to code formatting issues detected by the tile-lang formatter. Please run the formatter to ensure the code adheres to the project's style guidelines.
Run the formatter to fix the issues:
```bash
#!/bin/bash
# Run the tile-lang formatter
./format.sh
```

🧰 Tools
🪛 GitHub Actions: CI Test on AMD
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 GitHub Actions: CI Test on Metal
[error] 1-1: Code formatting changes detected by tile-lang formatter. Please review and stage changes or run the formatter. The step './format.sh' exited with code 1.
🤖 Prompt for AI Agents
In src/transform/split_host_device.cc lines 1-202 the file fails the project's
code-style checks; run the repository formatter and commit the resulting changes
to fix spacing/indentation and line-wrapping so CI passes: execute the provided
formatter script (./format.sh), review the modified file to ensure no logic
changes were introduced, stage and commit the formatted file, and push the
commit so the CI pipeline can re-run.
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/transform/split_host_device.cc (3)
52-52: Add defensive check for Target cast.

The `.value()` call will throw if the optional is empty. While the node should always be a Target when `attr_key == kTarget`, defensive programming suggests checking first. Apply this diff:

```diff
-      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
+      auto target_opt = op->node.as<tvm::Target>();
+      ICHECK(target_opt.defined()) << "Expected Target node for kTarget attribute";
+      auto device_target = target_opt.value().WithoutHost();
```

147-161: Clarify the empty kernel handling logic.

The fallback logic for empty kernels with `ForceSplit` is subtle. Consider adding a comment explaining why empty entry functions with device targets need forced splitting. Apply this diff:

```diff
   if (auto body = splitter(func->body); !body.same_as(func->body)) {
     func.CopyOnWrite()->body = body;
   } else if (!splitter.found_device_region()) {
+    // Empty entry functions with device targets still need host/device
+    // splitting to properly generate kernel stubs and device modules.
     if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
```

169-196: FreshGlobal guarantees unique kernel names.

Consider adding a test that runs SplitHostDevice on multiple functions sharing the same `name_prefix` to verify no name collisions occur.
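A hedged sketch of a helper for this suggestion; it assumes the generated kernels keep the `_kernel` naming described in this review and that the split module is available as a `tvm.IRModule`. The helper name is hypothetical and not part of the PR.

```python
# Hypothetical helper: check that SplitHostDevice produced unique kernel names.
def assert_unique_kernel_names(mod):
    names = [gv.name_hint for gv, _ in mod.functions.items()]
    kernel_names = [n for n in names if "_kernel" in n]
    assert len(kernel_names) == len(set(kernel_names)), kernel_names
```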
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/transform/lower_tile_op.cc (1 hunks)
- src/transform/split_host_device.cc (1 hunks)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/lower_tile_op.cc
- testing/python/issue/test_tilelang_issue_830.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/transform/__init__.py (1)
- src/transform/split_host_device.cc (4): SplitHostDevice (143-164), SplitHostDevice (143-144), SplitHostDevice (168-201), SplitHostDevice (168-168)

tilelang/engine/phase.py (2)
- src/transform/split_host_device.cc (4): SplitHostDevice (143-164), SplitHostDevice (143-144), SplitHostDevice (168-201), SplitHostDevice (168-168)
- tilelang/transform/__init__.py (1): SplitHostDevice (285-293)

src/transform/split_host_device.cc (2)
- tilelang/language/tir/op.py (1): ret (1758-1771)
- tilelang/transform/__init__.py (1): SplitHostDevice (285-293)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
🔇 Additional comments (5)
src/transform/split_host_device.cc (3)
94-105: LGTM! Error propagation logic is well-designed.

The conditional error propagation based on device type is correct. CPU, ExtDev, and Hexagon backends can return status codes, while GPU backends use void returns. The success value and return statement injection are properly implemented.

68-88: LGTM! Variable analysis and parameter sorting are correct.

The VarUseDefAnalyzer correctly identifies undefined variables (function parameters) and buffers. The sorting logic (non-handles first, then by name) ensures a stable, predictable parameter order for generated kernels.

203-206: LGTM! FFI registration follows TVM conventions.

The registration under `tl.transform.SplitHostDevice` properly exposes the pass to Python via the FFI layer.

tilelang/transform/__init__.py (1)

285-293: LGTM! FFI wrapper follows consistent patterns.

The SplitHostDevice wrapper correctly delegates to `_ffi_api.SplitHostDevice()` and includes an appropriate docstring highlighting the "even for empty kernels" capability.

tilelang/engine/phase.py (1)

189-189: LGTM! Correctly switches to TileLang-specific SplitHostDevice.

This change properly integrates the new TileLang SplitHostDevice pass (with empty kernel handling) into the compilation pipeline, replacing TVM's default implementation.
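For context, a hedged sketch of the kind of change this refers to (illustrative only, not the literal contents of `tilelang/engine/phase.py`): the lowering pass list now uses the TileLang pass in place of TVM's stock SplitHostDevice.

```python
# Illustrative pass-list fragment; surrounding passes are omitted.
import tilelang.transform

passes = [
    # ... earlier lowering passes ...
    tilelang.transform.SplitHostDevice(),  # previously tvm.tir.transform.SplitHostDevice()
    # ... later passes ...
]
```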
Actionable comments posted: 0
🧹 Nitpick comments (5)
testing/python/issue/test_tilelang_issue_830.py (5)
1-1: Consider more targeted linting suppressions.

The blanket `# ruff: noqa` suppresses all linting for the file. Since this test intentionally has unused variables to reproduce bugs, consider using targeted suppressions like `# ruff: noqa: F841` to allow other checks to run.

20-22: Consider adding assertions for more informative test failures.

The test calls the kernel but doesn't assert anything about the behavior. Consider adding an assertion to make test failures more informative, such as checking that the kernel executes without raising exceptions.

Apply this diff:

```diff
 def test_empty_kernel_lowering():
     kernel = _empty_kernel()
-    kernel()
+    assert kernel() is None  # or check expected return value
```

37-41: Consider adding an assertion.

The test could verify that the kernel executes successfully by asserting it doesn't raise an exception or by checking output values.

Apply this diff:

```diff
 @tilelang.testing.requires_cuda
 def test_empty_with_dead_code_kernel():
     kernel = _empty_with_dead_code_kernel()
     x = torch.randn((128,), dtype=torch.float32, device="cuda")
-    kernel(x)
+    kernel(x)  # Should not raise
+    assert x is not None  # Verify input remains valid
```

44-59: Remove debug print statements.

Both kernel variants contain `print(pid)` statements that appear to be debug code. These should either be removed or replaced with meaningful assertions if they're intended to verify kernel execution. Apply this diff:

```diff
     @T.prim_func
     def kernel_with_tuple_kernel_binding():
         with T.Kernel(1, threads=32) as (pid,):
-            print(pid)
             pass

     @T.prim_func
     def kernel_with_scalar_kernel_binding():
         with T.Kernel(1, threads=32) as pid:
-            print(pid)
             pass
```

62-67: Consider adding assertions to verify binding variants work correctly.

The test exercises both binding styles but doesn't assert their correctness. Consider adding assertions to make test failures more informative.

Apply this diff:

```diff
 def test_empty_kernel_with_binding_variants():
     kernel = _empty_kernel_with_binding_variants()
-    kernel()
+    assert kernel() is None  # Scalar binding should work

     tuple_kernel = _empty_kernel_with_binding_variants(use_tuple_binding=True)
-    tuple_kernel()
+    assert tuple_kernel() is None  # Tuple binding should work
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/transform/lower_tile_op.cc (1 hunks)
- testing/python/issue/test_tilelang_issue_830.py (1 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/language/kernel.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/transform/lower_tile_op.cc
- tilelang/engine/phase.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/kernel.py (1)
- tilelang/language/ast/ir.py (2): iter_var (1616-1639), var (1517-1533)

testing/python/issue/test_tilelang_issue_830.py (4)
- src/tl_templates/cuda/reduce.h (1): T (75-147)
- tilelang/jit/__init__.py (1): jit (244-317)
- tilelang/language/kernel.py (2): Kernel (229-303), threads (215-219)
- tilelang/language/__init__.py (1): symbolic (84-95)
🪛 Ruff (0.13.3)
tilelang/language/kernel.py
22-22: Unused lambda argument: self
(ARG005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: build-test-metal
🔇 Additional comments (9)
testing/python/issue/test_tilelang_issue_830.py (4)
3-6: LGTM! All imports are appropriately used in the test cases.

9-17: LGTM! The empty kernel factory correctly tests basic kernel compilation and launch frame creation.

25-34: LGTM! The dead code kernel correctly reproduces the bug scenario from issue #830 where an assigned but unused variable triggers compilation issues.

70-71: LGTM! The main guard correctly uses the TileLang testing framework's entry point.
tilelang/language/kernel.py (5)
83-92: LGTM! The helper function cleanly normalizes bindings to return a bare Var for single-dimension launches while preserving list semantics for multi-dimensional launches. This enables both the `with T.Kernel(n) as bx:` and `with T.Kernel(n, m) as (bx, by):` unpacking patterns.
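A hedged illustration of the normalization described above; the real helper is `_normalize_bindings` in `tilelang/language/kernel.py`, and this sketch only mirrors its described behavior, not its actual code.

```python
def _normalize_bindings(bindings):
    # One launch dimension: hand back the bare Var so `as pid:` binds it directly.
    # Multiple dimensions: keep list semantics so `as (bx, by):` unpacks per dimension.
    if len(bindings) == 1:
        return bindings[0]
    return list(bindings)
```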
118-118: LGTM! Applying `_normalize_bindings` to the CPU path ensures consistent unpacking behavior across both CPU and GPU kernel launches.

122-122: LGTM! Applying `_normalize_bindings` to the non-CPU path completes the symmetric unpacking support for single and multi-dimensional kernel launches.

257-281: LGTM! The updated examples clearly demonstrate the new unpacking behavior for single and multi-dimensional kernel launches. The examples cover 1-D CUDA, 2-D CUDA with multiple thread dimensions, and CPU kernels.

Note: The CPU example shows `as (i,):` with a trailing comma. Users might find it more natural to write `as i:` for single-dimension CPU kernels, which is now supported.
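A short usage sketch of the two styles those examples describe (hypothetical kernel bodies; grid shapes and thread counts are illustrative only):

```python
import tilelang
import tilelang.language as T


@tilelang.jit
def _binding_styles_demo(two_dim: bool = False):

    @T.prim_func
    def kernel_1d():
        with T.Kernel(8, threads=32) as bx:  # single dimension: bare Var binding
            pass

    @T.prim_func
    def kernel_2d():
        with T.Kernel(8, 4, threads=32) as (bx, by):  # two dimensions: tuple unpacking
            pass

    return kernel_2d if two_dim else kernel_1d
```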
12-22: Monkey-patching is safe. No code paths check kernel bindings' type as `list`; all `with T.Kernel` usages remain compatible. The unused-`self` warning is a false positive.
|
Also, fix the if not hasattr(Var, "__iter__"):
def _var_iter(self):
yield self
Var.__iter__ = _var_iter # type: ignore[attr-defined]
if not hasattr(Var, "__len__"):
Var.__len__ = lambda self: 1 # type: ignore[attr-defined]Now, both of these code styles are valid and will return the correct @tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):
@T.prim_func
def kernel_with_tuple_kernel_binding():
with T.Kernel(1, threads=32) as (pid,):
print(pid)
pass
@T.prim_func
def kernel_with_scalar_kernel_binding():
with T.Kernel(1, threads=32) as pid:
print(pid)
pass
return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding |
Local tests pass; we can merge this pull request now.
This PR fixes part of #830. It also adds a test for this issue. Since the problem hasn't been completely fixed yet, the test is commented out for now.
Summary by CodeRabbit
New Features
Bug Fixes
Tests