[Transform] Migrate LowerIntrin from tvm into tilelang #999
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Adds a C++ intrinsic-lowering transform (IntrinInjecter) that lowers intrinsics to device-specific IR, exposes a LowerIntrin pass via FFI and a Python wrapper, updates the engine to call tilelang's pass, relaxes a JIT index bound, adds device inference in the Cython wrapper, introduces ceildiv TileLang tests, and adjusts local CI pytest flags.
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Engine as tilelang.engine.lower
  participant Transform as tilelang.transform
  participant FFI as _ffi_api
  participant Cpp as LowerIntrin (C++)
  User->>Engine: trigger build / codegen
  Engine->>Transform: call LowerIntrin()
  Transform->>FFI: _ffi_api.LowerIntrin(PrimFunc)
  FFI->>Cpp: apply IntrinInjecter to PrimFunc.body
  Cpp-->>FFI: transformed PrimFunc
  FFI-->>Transform: return pass
  Transform-->>Engine: updated IRModule
  Engine-->>User: continue pipeline
```
```mermaid
sequenceDiagram
  autonumber
  participant Py as CythonKernelWrapper.forward
  participant Inputs as inputs (list)
  participant Infer as _infer_output_device
  participant Kernel as CompiledKernel
  Py->>Py: device = None
  Py->>Infer: scan Inputs for torch.Tensor.device
  Infer-->>Py: torch.device or current_cuda
  loop allocate outputs
    Py->>Py: create output on inferred device
  end
  Py->>Kernel: launch with inputs/outputs
  Kernel-->>Py: results
  Py-->>Inputs: return outputs
```
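As a rough illustration of the first diagram, here is how the exposed pass could be applied from Python. This is a minimal sketch assuming a standard TVM-style pass interface; `lower_intrinsics` is a hypothetical helper name, not code from the repository:

```python
# Minimal sketch: applying the LowerIntrin pass to an IRModule.
# Assumes the PrimFuncs already carry a Target attribute (e.g., set via
# BindTarget), which the pass requires.
import tvm
import tilelang


def lower_intrinsics(mod: tvm.IRModule) -> tvm.IRModule:
    # tilelang.transform.LowerIntrin() returns a pass object; calling it
    # on an IRModule returns the transformed module.
    return tilelang.transform.LowerIntrin()(mod)
```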
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed check (warning) · ✅ 2 passed
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)
24-31: Prevent out-of-bounds and fix inconsistent bounds/error message for result_idx
- Int case currently allows result_idx == len(params), which is OOB. Use >=.
- Error message claims range [-len(params)-1, len(params)-1], but valid is [-len(params), len(params)-1].
- List case now correctly allows -len(params); align message too. Optionally enforce int elements.
Apply:
```diff
-elif isinstance(result_idx, int):
-    if result_idx > len(params) or result_idx < -len(params):
-        raise ValueError(
-            f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
-        )
+elif isinstance(result_idx, int):
+    if result_idx >= len(params) or result_idx < -len(params):
+        raise ValueError(
+            f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
+        )
     if result_idx < 0:
         result_idx = len(params) + result_idx
     result_idx = [result_idx]
 elif isinstance(result_idx, list):
+    if not all(isinstance(v, int) for v in result_idx):
+        raise ValueError("result_idx should be a list of integers")
     for i, idx in enumerate(result_idx):
         if idx >= len(params) or idx < -len(params):
             raise ValueError(
-                f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
+                f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
             )
         if idx < 0:
             result_idx[i] = len(params) + idx
```

Also applies to: 33-39
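To make the bounds argument concrete, here is a hypothetical standalone helper (not repository code) showing why the valid range is [-len(params), len(params) - 1]:

```python
# Python-style negative indexing maps -k to n_params - k, so the smallest
# legal index is -n_params and the largest is n_params - 1.
def normalize_result_idx(idx: int, n_params: int) -> int:
    if idx >= n_params or idx < -n_params:
        raise ValueError(
            f"result_idx should be an integer between {-n_params} and {n_params - 1}")
    return idx + n_params if idx < 0 else idx


assert normalize_result_idx(-1, 4) == 3   # last parameter
assert normalize_result_idx(-4, 4) == 0   # smallest legal negative index
```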
🧹 Nitpick comments (5)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
148-155: Make device inference robust when no tensor inputs
If a kernel has only scalar/pointer inputs, raising is unnecessary. Prefer:
- Use declared output devices from buffer_device_map (for any result_idx).
- Else fall back to the current CUDA device if available, else CPU.

```diff
-    cdef object _infer_output_device(self, list inputs):
-        for tensor in inputs:
-            if isinstance(tensor, torch.Tensor):
-                return tensor.device
-        raise ValueError(
-            "Unable to determine output tensor device: expected at least one torch.Tensor input"
-        )
+    cdef object _infer_output_device(self, list inputs):
+        # Prefer first tensor input
+        for tensor in inputs:
+            if isinstance(tensor, torch.Tensor):
+                return tensor.device
+        # Try declared output devices
+        if self.buffer_device_map is not None and self.result_idx is not None:
+            for _, (buffer_idx, dev) in self.buffer_device_map.items():
+                if buffer_idx in self.result_idx:
+                    return dev
+        # Fallback: current CUDA device or CPU
+        if torch.cuda.is_available():
+            return torch.device("cuda", self.get_current_device())
+        return torch.device("cpu")
```
src/transform/lower_intrin.cc (2)
49-70: Consider caching attr maps across instances to reduce overhead
The constructor walks multiple Op attr maps on each pass invocation. For many PrimFuncs this adds overhead. Consider static/global caching keyed by (target, mtriple) to reuse attr_maps_ and fma_ safely.
200-213: Minor: fix typos in DLOG messages
"divident" -> "dividend"; "divsor" -> "divisor".

```diff
-    DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
+    DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of dividend";
...
-    DLOG(INFO)
-        << "LowerFloorMod: Cannot decide the sign of divsor and divident";
+    DLOG(INFO)
+        << "LowerFloorMod: Cannot decide the sign of divisor and dividend";
```

Also applies to: 320-333
tilelang/transform/__init__.py (1)
441-445: LGTM; enhance docstring to note target requirement
Mention that functions must have a Target attr (set by BindTarget/OptimizeForTarget), mirroring the pass's ICHECK.
```diff
-def LowerIntrin():
-    """LowerIntrin
-    """
+def LowerIntrin():
+    """Lower intrinsic calls/ops to target-specific IR.
+
+    Note: PrimFuncs must carry a Target attribute (e.g., via BindTarget),
+    otherwise the pass will fail.
+    """
     return _ffi_api.LowerIntrin()  # type: ignore
```

testing/python/language/test_tilelang_language_ceildiv.py (1)
1-4: LGTM: Test file structure follows conventions.
The test file is well-organized with proper imports, test discovery via `tilelang.testing.main()`, and appropriate use of `@tilelang.testing.requires_cuda` for platform-specific tests.
Consider adding a module-level docstring to document:
- the purpose of the ceildiv operation,
- the expected behavior for negative operands,
- a reference to the bug being fixed (PR #999).
Also applies to: 58-59
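One possible shape for that docstring; the wording below is illustrative, not taken from the repository:

```python
"""Tests for T.ceildiv lowering.

ceildiv(a, b) == ceil(a / b), rounding toward +inf, so negative dividends
give e.g. ceildiv(-1, 32) == 0. Added as regression coverage for the fix
that stops simplifying T.ceildiv(a, 32) into (a + 32 - 1) >> 5 (PR #999).
"""
```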
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/transform/lower_intrin.cc (1 hunks)
- testing/python/language/test_tilelang_language_ceildiv.py (1 hunks)
- tilelang/engine/lower.py (3 hunks)
- tilelang/jit/adapter/base.py (1 hunks)
- tilelang/jit/adapter/cython/cython_wrapper.pyx (4 hunks)
- tilelang/transform/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
tilelang/engine/lower.py (2)
- src/transform/lower_intrin.cc (2)
  - LowerIntrin (434-446)
  - LowerIntrin (434-434)
- tilelang/transform/__init__.py (1)
  - LowerIntrin (441-444)

src/transform/lower_intrin.cc (2)
- tilelang/language/tir/op.py (6)
  - ret (1758-1771)
  - truncdiv (3020-3043)
  - max_value (1860-1876)
  - truncmod (3046-3069)
  - floordiv (3072-3091)
  - floormod (3094-3113)
- tilelang/transform/__init__.py (1)
  - LowerIntrin (441-444)

tilelang/jit/adapter/base.py (1)
- tilelang/jit/kernel.py (1)
  - params (466-467)

tilelang/transform/__init__.py (1)
- src/transform/lower_intrin.cc (2)
  - LowerIntrin (434-446)
  - LowerIntrin (434-434)

testing/python/language/test_tilelang_language_ceildiv.py (2)
- tilelang/jit/__init__.py (1)
  - jit (244-317)
- tilelang/jit/adapter/base.py (1)
  - get_kernel_source (51-52)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: format-check
🔇 Additional comments (5)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
181-181: LGTM: device cached once and reused for all outputs
Lazy inference and reuse per forward is correct and avoids repeated calls.
Also applies to: 198-200
src/transform/lower_intrin.cc (1)
142-147: Ceildiv safety: good guard against P2 shift simplification
The explicit ceildiv-numerator check prevents rewriting (a + b - 1) / b into shifts. This aligns with the PR objective.
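A plain-Python illustration of the divergence this guard prevents, assuming C++ semantics for the generated code (`/` truncates toward zero, arithmetic `>>` floors):

```python
# For a negative numerator, truncating division and an arithmetic shift
# disagree, so rewriting truncdiv(a + b - 1, b) into a shift is unsound.
num = -2                 # e.g., a + b - 1 with a = -33, b = 32
trunc_q = int(num / 32)  # 0: what C++-style truncdiv yields
shift_q = num >> 5       # -1: what the shift rewrite would yield
assert trunc_q == 0 and shift_q == -1
```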
tilelang/engine/lower.py (1)
141-141: Approve LowerIntrin integration
tilelang.transform.LowerIntrin is implemented in Python and registered in C++ (src/transform/lower_intrin.cc:450); host/device pipelines invoke it consistently.

testing/python/language/test_tilelang_language_ceildiv.py (2)
6-14: LGTM: Static kernel definition is well-structured.
The kernel correctly captures `a` and `b` at JIT compilation time and uses `out_idx=[-1]` to indicate the output tensor. The type conversions to `T.int32` are appropriate.
31-39: Omitting `out_idx` on the dynamic kernel is intentional: dynamic JIT uses its default (`out_idx=None`) to exercise automatic output inference.
```python
def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)


def test_ceildiv():
    run_ceildiv(a=128, b=32)
    run_ceildiv(a=1, b=32)
    run_ceildiv(a=-1, b=32)
    run_ceildiv(a=-2, b=32)
```
Add assertions to validate ceildiv correctness.
The test currently only prints results without validating them. This means:
- Regressions won't be caught automatically
- The fix mentioned in the PR title ("Do not simplify `T.ceildiv(a, 32)` into `(a + 32 - 1) >> 5`") cannot be verified programmatically
Add explicit assertions for expected results:
```diff
 def run_ceildiv(a=128, b=32):
     kernel = _ceildiv_kernel(a, b)
     A = kernel()
     print(kernel.get_kernel_source())
     print(A)
+    return A


 def test_ceildiv():
-    run_ceildiv(a=128, b=32)
-    run_ceildiv(a=1, b=32)
-    run_ceildiv(a=-1, b=32)
-    run_ceildiv(a=-2, b=32)
+    import math
+    result = run_ceildiv(a=128, b=32)
+    assert result[0] == math.ceil(128 / 32), f"Expected {math.ceil(128/32)}, got {result[0]}"
+
+    result = run_ceildiv(a=1, b=32)
+    assert result[0] == math.ceil(1 / 32), f"Expected {math.ceil(1/32)}, got {result[0]}"
+
+    result = run_ceildiv(a=-1, b=32)
+    assert result[0] == math.ceil(-1 / 32), f"Expected {math.ceil(-1/32)}, got {result[0]}"
+
+    result = run_ceildiv(a=-2, b=32)
+    assert result[0] == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0]}"
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)
    return A


def test_ceildiv():
    import math
    result = run_ceildiv(a=128, b=32)
    assert result[0] == math.ceil(128 / 32), f"Expected {math.ceil(128/32)}, got {result[0]}"

    result = run_ceildiv(a=1, b=32)
    assert result[0] == math.ceil(1 / 32), f"Expected {math.ceil(1/32)}, got {result[0]}"

    result = run_ceildiv(a=-1, b=32)
    assert result[0] == math.ceil(-1 / 32), f"Expected {math.ceil(-1/32)}, got {result[0]}"

    result = run_ceildiv(a=-2, b=32)
    assert result[0] == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0]}"
```
```python
def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    print(kernel.get_kernel_source())
```
🛠️ Refactor suggestion | 🟠 Major
Verify the generated kernel doesn't use incorrect bitshift optimization.
The PR title explicitly mentions preventing simplification of `T.ceildiv(a, 32)` into `(a + 32 - 1) >> 5`. Consider adding assertions to verify the generated kernel source doesn't contain this incorrect pattern.
Example verification:
```python
def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    source = kernel.get_kernel_source()
    print(source)

    # Verify the kernel doesn't use bitshift optimization incorrectly
    # The actual verification depends on the expected lowering pattern
    assert "ceildiv" in source or "correct_pattern" in source, \
        "Kernel source should use correct ceildiv lowering"

    print(A)
    return A
```

Also applies to: 46-46
🤖 Prompt for AI Agents
testing/python/language/test_tilelang_language_ceildiv.py around lines 20 and
46: the test currently prints the generated kernel source but does not assert
the kernel avoids the incorrect bitshift optimization (e.g., "(a + 32 - 1) >> 5"
or ">> 5"); update the test to capture kernel.get_kernel_source() into a
variable and add assertions that disallow the bad pattern (assert ">> 5" not in
source or assert "(a + 32 - 1) >> 5" not in source) and assert the expected
correct lowering is present (e.g., assert "ceildiv" in source or assert the
known correct pattern string in source) so the test fails if the kernel uses the
incorrect bitshift simplification.
```python
def run_ceildiv_dyn(a=128, b=32):
    kernel = _ceildiv_kernel_dyn(b)
    A = torch.empty((1,), dtype=torch.int32, device="cuda")
    kernel(A, a)
    print(kernel.get_kernel_source())
    print(A)


@tilelang.testing.requires_cuda
def test_ceildiv_dyn():
    run_ceildiv_dyn(a=128, b=32)
    run_ceildiv_dyn(a=1, b=32)
    run_ceildiv_dyn(a=-1, b=32)
    run_ceildiv_dyn(a=-2, b=32)
```
Add assertions and expand test coverage.
Similar to the static test, this dynamic test lacks assertions for correctness validation.
Additionally, consider testing edge cases:
- Division by zero behavior
- Both operands negative
- Large values that might overflow
- Zero as dividend
Apply this diff to add assertions:
```diff
 def run_ceildiv_dyn(a=128, b=32):
     kernel = _ceildiv_kernel_dyn(b)
     A = torch.empty((1,), dtype=torch.int32, device="cuda")
     kernel(A, a)
     print(kernel.get_kernel_source())
     print(A)
+    return A


 @tilelang.testing.requires_cuda
 def test_ceildiv_dyn():
-    run_ceildiv_dyn(a=128, b=32)
-    run_ceildiv_dyn(a=1, b=32)
-    run_ceildiv_dyn(a=-1, b=32)
-    run_ceildiv_dyn(a=-2, b=32)
+    import math
+    result = run_ceildiv_dyn(a=128, b=32)
+    assert result[0].item() == math.ceil(128 / 32), f"Expected {math.ceil(128/32)}, got {result[0].item()}"
+
+    result = run_ceildiv_dyn(a=1, b=32)
+    assert result[0].item() == math.ceil(1 / 32), f"Expected {math.ceil(1/32)}, got {result[0].item()}"
+
+    result = run_ceildiv_dyn(a=-1, b=32)
+    assert result[0].item() == math.ceil(-1 / 32), f"Expected {math.ceil(-1/32)}, got {result[0].item()}"
+
+    result = run_ceildiv_dyn(a=-2, b=32)
+    assert result[0].item() == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0].item()}"
```

🤖 Prompt for AI Agents
+    assert result[0].item() == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0].item()}"🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_ceildiv.py around lines 42–55,
the dynamic ceildiv test only prints results and lacks correctness checks and
edge-case coverage; update run_ceildiv_dyn to read back the kernel output and
assert it equals the expected Python ceil division result (use integer math
consistent with the kernel semantics), then expand test_ceildiv_dyn to include
assertions for: normal positive inputs (e.g. a=128,b=32), small magnitudes (a=1,
a=-1, a=-2), zero dividend (a=0), both operands negative (e.g. a=-5,b=-2), large
values near int32 limits to detect overflow, and explicit handling of
division-by-zero (call with b=0 and assert expected behavior — either expect an
exception or a defined sentinel/result per implementation); make sure all tensor
reads use CUDA synchronization if needed and compare scalar A.item() to the
computed expected integer result.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- maint/scripts/run_local_ci_test.sh (1 hunks)
- tilelang/jit/adapter/cython/cython_wrapper.pyx (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: build-test-amd
🔇 Additional comments (2)
maint/scripts/run_local_ci_test.sh (1)
14-14: LGTM: Enhanced test diagnostics.
The additional pytest flags improve CI observability by providing verbose output, test durations, local variables on failures, and clearing the cache for reproducible runs.
Also applies to: 19-19
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
179-179: LGTM: Efficient lazy device inference.
The lazy initialization pattern efficiently determines and caches the output device only when needed, avoiding redundant inference for multiple outputs. Once the bug in `_infer_output_device` (lines 148-152) is fixed, this logic will work correctly.
Also applies to: 195-197
```cython
cdef object _infer_output_device(self, list inputs):
    for tensor in inputs:
        if isinstance(tensor, torch.Tensor):
            return tensor.device
    return torch.cuda.current_device()
```
Critical: Type mismatch in fallback device.
Line 152 returns an integer from `torch.cuda.current_device()`, but line 151 returns a `torch.device` object. Line 205 expects a `torch.device` object or device string (e.g., 'cuda:0'), not an integer, so the fallback will fail at runtime with a TypeError.
Additionally, line 152 doesn't check torch.cuda.is_available() before accessing CUDA functions, which will raise an error on CPU-only systems.
Apply this diff to fix both issues:
```diff
 cdef object _infer_output_device(self, list inputs):
     for tensor in inputs:
         if isinstance(tensor, torch.Tensor):
             return tensor.device
-    return torch.cuda.current_device()
+    if torch.cuda.is_available():
+        return torch.device('cuda', torch.cuda.current_device())
+    return torch.device('cpu')
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cython
cdef object _infer_output_device(self, list inputs):
    for tensor in inputs:
        if isinstance(tensor, torch.Tensor):
            return tensor.device
    if torch.cuda.is_available():
        return torch.device('cuda', torch.cuda.current_device())
    return torch.device('cpu')
```
🤖 Prompt for AI Agents
In tilelang/jit/adapter/cython/cython_wrapper.pyx around lines 148-152, the
fallback returns an integer from torch.cuda.current_device() and doesn't check
CUDA availability; change the function to return a torch.device object
consistently: iterate inputs and return tensor.device if a torch.Tensor is
found, otherwise if torch.cuda.is_available() use torch.cuda.current_device() to
build and return a torch.device for the CUDA index (e.g., torch.device('cuda',
idx) or 'cuda:idx' string), and if CUDA is not available return
torch.device('cpu').
This is not a simple change; we should rethink the relationship between floordiv in Python and truncdiv in C++.
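A small sketch of the mismatch the comment refers to: Python's `//` is floor division, while C++ `/` truncates toward zero; the two disagree exactly when the operands have opposite signs.

```python
# Emulate C++ truncating division for a < 0 < b and compare with Python's //.
a, b = -7, 32
floor_q = a // b      # -1: Python floordiv rounds toward -inf
trunc_q = -(-a // b)  # 0: truncation toward zero, as C++ '/' would give
assert (floor_q, trunc_q) == (-1, 0)
```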
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/transform/lower_intrin.cc (2)
111-115: Document correctness for negative dividends.
The power-of-2 optimization `floordiv(a, 2^n) → a >> n` is correct when right shift is arithmetic (sign-extending), but this is implementation-defined in C++ before C++20. While `support_bitwise_op_` likely gates this optimization on platforms with known arithmetic right shift behavior, the code lacks a comment explaining why this transformation is valid for negative values of `a`.
For readers unfamiliar with the property that arithmetic right shift of negative values matches floor division for power-of-2 divisors, this could appear to be a bug.
Consider adding a comment:
```diff
 if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
-  // lower to right shift if possible.
+  // Lower to right shift. For negative dividends, arithmetic right shift
+  // (sign-extending) matches floordiv semantics for power-of-2 divisors.
+  // support_bitwise_op_ is set only when arithmetic right shift is guaranteed.
   return op->a >> make_const(dtype, shift);
 }
```
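A quick Python check of the property the suggested comment states (Python integers shift arithmetically, mirroring the sign-extension argument):

```python
# Arithmetic right shift matches floor division for power-of-2 divisors,
# including negative dividends.
for a in (-65, -33, -7, -1, 0, 1, 31, 128):
    assert a >> 5 == a // 32
```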
213-218: Document correctness for negative dividends.
Similar to the floordiv optimization, the power-of-2 masking for floormod relies on two's complement representation to produce correct results for negative values. For example, `floormod(-7, 4) = 1` and `-7 & 3 = 1` in two's complement. While mathematically correct, this is subtle and worth documenting.
Consider adding a comment:
```diff
 if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
-  // lower to masking if possible.
+  // Lower to masking. Two's complement ensures a & mask produces the correct
+  // floormod result for both positive and negative dividends with power-of-2 divisors.
   int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
   return op->a & make_const(dtype, mask);
 }
```
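And the matching check for the masking claim, again in plain Python where `%` is floor modulo:

```python
# Masking with 2^n - 1 reproduces floormod for power-of-2 divisors,
# including negative dividends, thanks to two's complement semantics.
for a in (-65, -33, -7, -1, 0, 1, 31, 128):
    assert a & 31 == a % 32
```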
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/transform/lower_intrin.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/lower_intrin.cc (2)
- tilelang/language/tir/op.py (5)
  - ret (1758-1771)
  - truncdiv (3020-3043)
  - truncmod (3046-3069)
  - floordiv (3072-3091)
  - floormod (3094-3113)
- tilelang/transform/__init__.py (1)
  - LowerIntrin (441-444)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: build-test-amd
🔇 Additional comments (1)
src/transform/lower_intrin.cc (1)
126-167: Confirm overflow safety for offset calculation.
No existing overflow checks were found for the expression `op->a + op->b * ceildiv`. Ensure either (see the sketch below):
- `analyzer_` bounds guarantee no overflow for both the multiplication and addition at runtime, or
- explicit overflow checks/casts to a wider type are added.
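A sketch of the failure mode in question, using NumPy's fixed-width integers to emulate 32-bit arithmetic; the values are illustrative, not taken from the pass:

```python
import numpy as np

# a + b * q can wrap under int32 even though each operand fits,
# which is why bounds from analyzer_ or a widening cast are needed.
with np.errstate(over="ignore"):
    a = np.int32(2_000_000_000)
    b = np.int32(32)
    q = np.int32(70_000_000)
    print(a + b * q)  # wraps to a negative value instead of 4_240_000_000
```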
Title changed: "Do not simplify T.ceildiv(a, 32) into (a + 32 -1) >> 5" → "[Transform] Migrate LowerIntrin from tvm into tilelang"
This pull request introduces new tests for the `ceildiv` operation and refactors how intrinsic lowering is handled in the code generation pipeline. It also improves device inference for output tensors in the Cython kernel adapter and fixes a minor bug in result index validation.

Testing improvements:
- Added `test_tilelang_language_ceildiv.py` to validate the `ceildiv` operation for various input cases, including dynamic kernel invocation and CUDA device support.

Code generation pipeline updates:
- Updated `tilelang/engine/lower.py` to use `tilelang.transform.LowerIntrin()` instead of `tir.transform.LowerIntrin()` for host and device codegen functions, aligning with custom tilelang transformations. [1] [2] [3]
- Added a `LowerIntrin` function to `tilelang/transform/__init__.py` to expose the new intrinsic lowering pass.

Cython kernel adapter enhancements:
- Added `_infer_output_device` in `CythonKernelWrapper` to robustly determine the device for output tensors based on input tensors, improving device assignment logic. [1] [2] [3]

Bug fixes:
- Updated result index validation in `tilelang/jit/adapter/base.py` to properly handle negative indices, ensuring accurate parameter indexing.
New Features
Bug Fixes
Tests
Chores