
Conversation

@LeiWang1999 (Member) commented Oct 13, 2025

This pull request introduces new tests for the ceildiv operation and refactors how intrinsic lowering is handled in the code generation pipeline. It also improves device inference for output tensors in the Cython kernel adapter and fixes a minor bug in result index validation.

Testing improvements:

  • Added a new test file test_tilelang_language_ceildiv.py to validate the ceildiv operation for various input cases, including dynamic kernel invocation and CUDA device support.

Code generation pipeline updates:

  • Changed intrinsic lowering in tilelang/engine/lower.py to use tilelang.transform.LowerIntrin() instead of tir.transform.LowerIntrin() for host and device codegen functions, aligning with custom tilelang transformations. [1] [2] [3]
  • Added a LowerIntrin function to tilelang/transform/__init__.py to expose the new intrinsic lowering pass (see the sketch below).
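For orientation, here is a minimal sketch of the intended call pattern; the helper name build_device_pipeline and the exact pass ordering are illustrative assumptions, not an excerpt from tilelang/engine/lower.py:

# Illustrative sketch only: helper name and pass list are hypothetical.
import tvm
import tilelang

def build_device_pipeline() -> tvm.transform.Sequential:
    # The engine now invokes tilelang's LowerIntrin() (backed by
    # src/transform/lower_intrin.cc) instead of tir.transform.LowerIntrin().
    return tvm.transform.Sequential([
        tilelang.transform.LowerIntrin(),
    ])

# Usage (hypothetical): device_mod = build_device_pipeline()(device_mod)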

Cython kernel adapter enhancements:

  • Implemented _infer_output_device in CythonKernelWrapper to robustly determine the device for output tensors based on input tensors, improving device assignment logic. [1] [2] [3]

Bug fixes:

  • Corrected result index validation in tilelang/jit/adapter/base.py to properly handle negative indices, ensuring accurate parameter indexing (a short illustration of the normalization follows).
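As a quick illustration of the corrected bounds (a standalone sketch, not the adapter code itself): for a parameter list of length n, valid result indices span [-n, n - 1], and negative indices are normalized by adding n.

def normalize_result_idx(result_idx: int, num_params: int) -> int:
    # Valid range is [-num_params, num_params - 1]; -num_params must be
    # accepted, which is the bound the fix loosens.
    if result_idx >= num_params or result_idx < -num_params:
        raise ValueError(
            f"result_idx should be between {-num_params} and {num_params - 1}")
    return result_idx + num_params if result_idx < 0 else result_idx

assert normalize_result_idx(-1, 3) == 2
assert normalize_result_idx(-3, 3) == 0  # -len(params) is now valid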

Summary by CodeRabbit

  • New Features

    • Target-aware intrinsic lowering pass for improved codegen and math accuracy (div/mod, FMA).
    • Public API to invoke intrinsic lowering as part of the compile pipeline.
    • JIT outputs now infer device from inputs automatically.
  • Bug Fixes

    • Accepts -len(...) as a valid negative index in JIT adapter validation.
    • More reliable device selection to avoid mismatched-device outputs.
  • Tests

    • Added ceildiv tests (static, dynamic, CUDA).
  • Chores

    • More verbose local CI pytest reporting and diagnostics.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot (Contributor) commented Oct 13, 2025

Walkthrough

Adds a C++ intrinsic-lowering transform (IntrinInjecter) that lowers intrinsics to device-specific IR, exposes a LowerIntrin pass via FFI and a Python wrapper, updates the engine to call tilelang's pass, relaxes a JIT index bound, adds device inference in the Cython wrapper, introduces ceildiv TileLang tests, and adjusts local CI pytest flags.

Changes

Cohort / File(s) — Summary of changes

  • Intrinsic lowering pass (C++ core) — src/transform/lower_intrin.cc
    Implements IntrinInjecter (IRMutatorWithAnalyzer) to lower intrinsics via attribute-driven rules: function-call dispatch, optional FMA emission, integer floor-div/mod rewrites (sign/edge-case paths and bitwise shortcuts), the SwapBroadcastCast heuristic, and related helpers. Adds Stmt LowerIntrinStmt(Stmt, const std::string&), tvm::transform::Pass LowerIntrin(), and FFI registration.
  • Transform API — tilelang/transform/__init__.py
    Adds a public wrapper LowerIntrin() delegating to _ffi_api.LowerIntrin().
  • Engine integration — tilelang/engine/lower.py
    Replaces calls to tir.transform.LowerIntrin() with tilelang.transform.LowerIntrin() in the host/device codegen paths.
  • JIT adapter (base) — tilelang/jit/adapter/base.py
    Loosens the negative-index bound check for result_idx lists: allows -len(params) and keeps normalizing negative indices to positive.
  • JIT adapter (Cython wrapper) — tilelang/jit/adapter/cython/cython_wrapper.pyx
    Adds cdef object _infer_output_device(self, list inputs) to find a torch device among the inputs (falling back to the current CUDA device). Caches the inferred device in forward(...) and uses it when allocating output tensors.
  • Tests (ceildiv) — testing/python/language/test_tilelang_language_ceildiv.py
    Adds tests and helpers for ceildiv kernels (static and dynamic), host runners that print kernel source/results, and a CUDA-guarded dynamic test.
  • CI script — maint/scripts/run_local_ci_test.sh
    Expands pytest invocations to include --verbose --color=yes --durations=0 --showlocals --cache-clear for more detailed diagnostics.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Engine as tilelang.engine.lower
  participant Transform as tilelang.transform
  participant FFI as _ffi_api
  participant Cpp as LowerIntrin (C++)

  User->>Engine: trigger build / codegen
  Engine->>Transform: call LowerIntrin()
  Transform->>FFI: _ffi_api.LowerIntrin(PrimFunc)
  FFI->>Cpp: apply IntrinInjecter to PrimFunc.body
  Cpp-->>FFI: transformed PrimFunc
  FFI-->>Transform: return pass
  Transform-->>Engine: updated IRModule
  Engine-->>User: continue pipeline
sequenceDiagram
  autonumber
  participant Py as CythonKernelWrapper.forward
  participant Inputs as inputs (list)
  participant Infer as _infer_output_device
  participant Kernel as CompiledKernel

  Py->>Py: device = None
  Py->>Infer: scan Inputs for torch.Tensor.device
  Infer-->>Py: torch.device or current_cuda
  loop allocate outputs
    Py->>Py: create output on inferred device
  end
  Py->>Kernel: launch with inputs/outputs
  Kernel-->>Py: results
  Py-->>Inputs: return outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • fix amd tir&add examples #784 — Modifies intrinsic-lowering infrastructure and HIP-specific intrinsic mappings; likely related to intrinsic lowering changes.

Suggested reviewers

  • tzj-fxz

Poem

I nibble through IR with careful cheer,
Lowering intrinsics, the pathways clear.
I sniff devices from tensors near,
Ceildiv kernels hop into gear.
Thump-thump, I stamp this merge—hip hooray! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 8.33%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: the title focuses on migrating the LowerIntrin pass, which is indeed part of the changeset, but it omits significant additions such as ceildiv tests, adapter updates, and bug fixes, making it only partially representative of the PR’s primary objectives.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/adapter/base.py (1)

24-31: Prevent out-of-bounds and fix inconsistent bounds/error message for result_idx

  • Int case currently allows result_idx == len(params), which is OOB. Use >=.
  • Error message claims range [-len(params)-1, len(params)-1], but valid is [-len(params), len(params)-1].
  • List case now correctly allows -len(params); align message too. Optionally enforce int elements.

Apply:

-        elif isinstance(result_idx, int):
-            if result_idx > len(params) or result_idx < -len(params):
-                raise ValueError(
-                    f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
-                )
+        elif isinstance(result_idx, int):
+            if result_idx >= len(params) or result_idx < -len(params):
+                raise ValueError(
+                    f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
+                )
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
         elif isinstance(result_idx, list):
+            if not all(isinstance(v, int) for v in result_idx):
+                raise ValueError("result_idx should be a list of integers")
             for i, idx in enumerate(result_idx):
-                if idx >= len(params) or idx < -len(params):
+                if idx >= len(params) or idx < -len(params):
                     raise ValueError(
-                        f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}"
+                        f"result_idx should be an integer between {-len(params)} and {len(params) - 1}"
                     )
                 if idx < 0:
                     result_idx[i] = len(params) + idx

Also applies to: 33-39

🧹 Nitpick comments (5)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

148-155: Make device inference robust when no tensor inputs

If a kernel has only scalar/pointer inputs, raising is unnecessary. Prefer:

  • Use declared output devices from buffer_device_map (for any result_idx).
  • Else fall back to CUDA current device if available, else CPU.
-    cdef object _infer_output_device(self, list inputs):
-        for tensor in inputs:
-            if isinstance(tensor, torch.Tensor):
-                return tensor.device
-        raise ValueError(
-            "Unable to determine output tensor device: expected at least one torch.Tensor input"
-        )
+    cdef object _infer_output_device(self, list inputs):
+        # Prefer first tensor input
+        for tensor in inputs:
+            if isinstance(tensor, torch.Tensor):
+                return tensor.device
+        # Try declared output devices
+        if self.buffer_device_map is not None and self.result_idx is not None:
+            for _, (buffer_idx, dev) in self.buffer_device_map.items():
+                if buffer_idx in self.result_idx:
+                    return dev
+        # Fallback: current CUDA device or CPU
+        if torch.cuda.is_available():
+            return torch.device("cuda", self.get_current_device())
+        return torch.device("cpu")
src/transform/lower_intrin.cc (2)

49-70: Consider caching attr maps across instances to reduce overhead

Constructor walks multiple Op attr maps on each pass invocation. For many PrimFuncs this adds overhead. Consider static/global caching keyed by (target, mtriple) to reuse attr_maps_ and fma_ safely.


200-213: Minor: fix typos in DLOG messages

“divident” -> “dividend”; “divsor” -> “divisor”.

-      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
+      DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of dividend";
...
-        DLOG(INFO)
-            << "LowerFloorMod: Cannot decide the sign of divsor and divident";
+        DLOG(INFO)
+            << "LowerFloorMod: Cannot decide the sign of divisor and dividend";

Also applies to: 320-333

tilelang/transform/__init__.py (1)

441-445: LGTM; enhance docstring to note target requirement

Mention that functions must have a Target attr (set by BindTarget/OptimizeForTarget), mirroring the pass’s ICHECK.

-def LowerIntrin():
-    """LowerIntrin
-    """
+def LowerIntrin():
+    """Lower intrinsic calls/ops to target-specific IR.
+
+    Note: PrimFuncs must carry a Target attribute (e.g., via BindTarget),
+    otherwise the pass will fail.
+    """
     return _ffi_api.LowerIntrin()  # type: ignore
testing/python/language/test_tilelang_language_ceildiv.py (1)

1-4: LGTM: Test file structure follows conventions.

The test file is well-organized with proper imports, test discovery via tilelang.testing.main(), and appropriate use of @tilelang.testing.requires_cuda for platform-specific tests.

Consider adding a module-level docstring to document (a brief sketch follows the list):

  • Purpose of ceildiv operation
  • Expected behavior for negative operands
  • Reference to the bug being fixed (PR #999)

Also applies to: 58-59
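A minimal sketch of such a docstring, with wording that is purely illustrative:

"""Tests for T.ceildiv lowering in TileLang.

Covers static and dynamic ceildiv kernels, including negative dividends,
and guards against simplifying T.ceildiv(a, 32) into (a + 32 - 1) >> 5
(the bug addressed by PR #999).
"""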

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 340bfc5 and c899ab9.

📒 Files selected for processing (6)
  • src/transform/lower_intrin.cc (1 hunks)
  • testing/python/language/test_tilelang_language_ceildiv.py (1 hunks)
  • tilelang/engine/lower.py (3 hunks)
  • tilelang/jit/adapter/base.py (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (4 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
tilelang/engine/lower.py (2)
src/transform/lower_intrin.cc (2)
  • LowerIntrin (434-446)
  • LowerIntrin (434-434)
tilelang/transform/__init__.py (1)
  • LowerIntrin (441-444)
src/transform/lower_intrin.cc (2)
tilelang/language/tir/op.py (6)
  • ret (1758-1771)
  • truncdiv (3020-3043)
  • max_value (1860-1876)
  • truncmod (3046-3069)
  • floordiv (3072-3091)
  • floormod (3094-3113)
tilelang/transform/__init__.py (1)
  • LowerIntrin (441-444)
tilelang/jit/adapter/base.py (1)
tilelang/jit/kernel.py (1)
  • params (466-467)
tilelang/transform/__init__.py (1)
src/transform/lower_intrin.cc (2)
  • LowerIntrin (434-446)
  • LowerIntrin (434-434)
testing/python/language/test_tilelang_language_ceildiv.py (2)
tilelang/jit/__init__.py (1)
  • jit (244-317)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (51-52)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: format-check
🔇 Additional comments (5)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

181-181: LGTM: device cached once and reused for all outputs

Lazy inference and reuse per forward is correct and avoids repeated calls.

Also applies to: 198-200

src/transform/lower_intrin.cc (1)

142-147: Ceildiv safety: good guard against P2 shift simplification

The explicit ceildiv-numerator check prevents rewriting (a + b - 1) / b into shifts. This aligns with the PR objective.

tilelang/engine/lower.py (1)

141-141: Approve LowerIntrin integration
tilelang.transform.LowerIntrin is implemented in Python and registered in C++ (src/transform/lower_intrin.cc:450); host/device pipelines invoke it consistently.

testing/python/language/test_tilelang_language_ceildiv.py (2)

6-14: LGTM: Static kernel definition is well-structured.

The kernel correctly captures a and b at JIT compilation time and uses out_idx=[-1] to indicate the output tensor. The type conversions to T.int32 are appropriate.


31-39: Omitting out_idx on dynamic kernel is intentional: dynamic JIT uses its default (out_idx=None) to exercise automatic output inference.

Comment on lines +17 to +28
def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    print(kernel.get_kernel_source())
    print(A)


def test_ceildiv():
    run_ceildiv(a=128, b=32)
    run_ceildiv(a=1, b=32)
    run_ceildiv(a=-1, b=32)
    run_ceildiv(a=-2, b=32)

⚠️ Potential issue | 🟠 Major

Add assertions to validate ceildiv correctness.

The test currently only prints results without validating them. This means:

  • Regressions won't be caught automatically
  • The fix mentioned in the PR title ("Do not simplify T.ceildiv(a, 32) into (a + 32 -1) >> 5") cannot be verified programmatically

Add explicit assertions for expected results:

 def run_ceildiv(a=128, b=32):
     kernel = _ceildiv_kernel(a, b)
     A = kernel()
     print(kernel.get_kernel_source())
     print(A)
+    return A
 
 
 def test_ceildiv():
-    run_ceildiv(a=128, b=32)
-    run_ceildiv(a=1, b=32)
-    run_ceildiv(a=-1, b=32)
-    run_ceildiv(a=-2, b=32)
+    import math
+    result = run_ceildiv(a=128, b=32)
+    assert result[0] == math.ceil(128 / 32), f"Expected {math.ceil(128/32)}, got {result[0]}"
+    
+    result = run_ceildiv(a=1, b=32)
+    assert result[0] == math.ceil(1 / 32), f"Expected {math.ceil(1/32)}, got {result[0]}"
+    
+    result = run_ceildiv(a=-1, b=32)
+    assert result[0] == math.ceil(-1 / 32), f"Expected {math.ceil(-1/32)}, got {result[0]}"
+    
+    result = run_ceildiv(a=-2, b=32)
+    assert result[0] == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0]}"

def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    print(kernel.get_kernel_source())

🛠️ Refactor suggestion | 🟠 Major

Verify the generated kernel doesn't use incorrect bitshift optimization.

The PR title explicitly mentions preventing simplification of T.ceildiv(a, 32) into (a + 32 -1) >> 5. Consider adding assertions to verify the generated kernel source doesn't contain this incorrect pattern.

Example verification:

def run_ceildiv(a=128, b=32):
    kernel = _ceildiv_kernel(a, b)
    A = kernel()
    source = kernel.get_kernel_source()
    print(source)
    
    # Verify the kernel doesn't use bitshift optimization incorrectly
    # The actual verification depends on the expected lowering pattern
    assert "ceildiv" in source or "correct_pattern" in source, \
        "Kernel source should use correct ceildiv lowering"
    
    print(A)
    return A

Also applies to: 46-46

🤖 Prompt for AI Agents
testing/python/language/test_tilelang_language_ceildiv.py around lines 20 and
46: the test currently prints the generated kernel source but does not assert
the kernel avoids the incorrect bitshift optimization (e.g., "(a + 32 - 1) >> 5"
or ">> 5"); update the test to capture kernel.get_kernel_source() into a
variable and add assertions that disallow the bad pattern (assert ">> 5" not in
source or assert "(a + 32 - 1) >> 5" not in source) and assert the expected
correct lowering is present (e.g., assert "ceildiv" in source or assert the
known correct pattern string in source) so the test fails if the kernel uses the
incorrect bitshift simplification.

Comment on lines +42 to +55
def run_ceildiv_dyn(a=128, b=32):
    kernel = _ceildiv_kernel_dyn(b)
    A = torch.empty((1,), dtype=torch.int32, device="cuda")
    kernel(A, a)
    print(kernel.get_kernel_source())
    print(A)


@tilelang.testing.requires_cuda
def test_ceildiv_dyn():
    run_ceildiv_dyn(a=128, b=32)
    run_ceildiv_dyn(a=1, b=32)
    run_ceildiv_dyn(a=-1, b=32)
    run_ceildiv_dyn(a=-2, b=32)

⚠️ Potential issue | 🟠 Major

Add assertions and expand test coverage.

Similar to the static test, this dynamic test lacks assertions for correctness validation.

Additionally, consider testing edge cases:

  • Division by zero behavior
  • Both operands negative
  • Large values that might overflow
  • Zero as dividend

Apply this diff to add assertions:

 def run_ceildiv_dyn(a=128, b=32):
     kernel = _ceildiv_kernel_dyn(b)
     A = torch.empty((1,), dtype=torch.int32, device="cuda")
     kernel(A, a)
     print(kernel.get_kernel_source())
     print(A)
+    return A
 
 
 @tilelang.testing.requires_cuda
 def test_ceildiv_dyn():
-    run_ceildiv_dyn(a=128, b=32)
-    run_ceildiv_dyn(a=1, b=32)
-    run_ceildiv_dyn(a=-1, b=32)
-    run_ceildiv_dyn(a=-2, b=32)
+    import math
+    result = run_ceildiv_dyn(a=128, b=32)
+    assert result[0].item() == math.ceil(128 / 32), f"Expected {math.ceil(128/32)}, got {result[0].item()}"
+    
+    result = run_ceildiv_dyn(a=1, b=32)
+    assert result[0].item() == math.ceil(1 / 32), f"Expected {math.ceil(1/32)}, got {result[0].item()}"
+    
+    result = run_ceildiv_dyn(a=-1, b=32)
+    assert result[0].item() == math.ceil(-1 / 32), f"Expected {math.ceil(-1/32)}, got {result[0].item()}"
+    
+    result = run_ceildiv_dyn(a=-2, b=32)
+    assert result[0].item() == math.ceil(-2 / 32), f"Expected {math.ceil(-2/32)}, got {result[0].item()}"
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_ceildiv.py around lines 42–55,
the dynamic ceildiv test only prints results and lacks correctness checks and
edge-case coverage; update run_ceildiv_dyn to read back the kernel output and
assert it equals the expected Python ceil division result (use integer math
consistent with the kernel semantics), then expand test_ceildiv_dyn to include
assertions for: normal positive inputs (e.g. a=128,b=32), small magnitudes (a=1,
a=-1, a=-2), zero dividend (a=0), both operands negative (e.g. a=-5,b=-2), large
values near int32 limits to detect overflow, and explicit handling of
division-by-zero (call with b=0 and assert expected behavior — either expect an
exception or a defined sentinel/result per implementation); make sure all tensor
reads use CUDA synchronization if needed and compare scalar A.item() to the
computed expected integer result.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c899ab9 and e21d12d.

📒 Files selected for processing (2)
  • maint/scripts/run_local_ci_test.sh (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
🔇 Additional comments (2)
maint/scripts/run_local_ci_test.sh (1)

14-14: LGTM: Enhanced test diagnostics.

The additional pytest flags improve CI observability by providing verbose output, test durations, local variables on failures, and clearing the cache for reproducible runs.

Also applies to: 19-19

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

179-179: LGTM: Efficient lazy device inference.

The lazy initialization pattern efficiently determines and caches the output device only when needed, avoiding redundant inference for multiple outputs. Once the bug in _infer_output_device (lines 148-152) is fixed, this logic will work correctly.

Also applies to: 195-197

Comment on lines +148 to +152
cdef object _infer_output_device(self, list inputs):
    for tensor in inputs:
        if isinstance(tensor, torch.Tensor):
            return tensor.device
    return torch.cuda.current_device()

⚠️ Potential issue | 🔴 Critical

Critical: Type mismatch in fallback device.

Line 152 returns an integer from torch.cuda.current_device(), but line 151 returns a torch.device object. Line 205 expects a torch.device object or device string (e.g., 'cuda:0'), not an integer, so the fallback will fail at runtime with a TypeError.

Additionally, line 152 doesn't check torch.cuda.is_available() before accessing CUDA functions, which will raise an error on CPU-only systems.

Apply this diff to fix both issues:

 cdef object _infer_output_device(self, list inputs):
     for tensor in inputs:
         if isinstance(tensor, torch.Tensor):
             return tensor.device
-    return torch.cuda.current_device()
+    if torch.cuda.is_available():
+        return torch.device('cuda', torch.cuda.current_device())
+    return torch.device('cpu')
🤖 Prompt for AI Agents
In tilelang/jit/adapter/cython/cython_wrapper.pyx around lines 148-152, the
fallback returns an integer from torch.cuda.current_device() and doesn't check
CUDA availability; change the function to return a torch.device object
consistently: iterate inputs and return tensor.device if a torch.Tensor is
found, otherwise if torch.cuda.is_available() use torch.cuda.current_device() to
build and return a torch.device for the CUDA index (e.g., torch.device('cuda',
idx) or 'cuda:idx' string), and if CUDA is not available return
torch.device('cpu').

@LeiWang1999 (Member, Author) commented

This is not a simple change; we should rethink the relationship between floordiv in Python and truncdiv in C++.

  1. What about a const fold in Python when we directly invoke T.floordiv(-1, 2)? (A quick illustration of the floordiv/truncdiv gap follows.)
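To make the gap concrete, a small standalone Python illustration (not code from the PR): Python's // is floor division, while C/C++ integer division truncates toward zero, so the two disagree exactly when the operands have opposite signs.

def truncdiv(a: int, b: int) -> int:
    # Emulates C/C++ integer division, which truncates toward zero.
    q = abs(a) // abs(b)
    return -q if (a < 0) != (b < 0) else q

# Python floordiv rounds toward negative infinity.
assert -1 // 2 == -1
# C-style truncating division rounds toward zero.
assert truncdiv(-1, 2) == 0
# The two agree whenever the signs match.
assert -4 // 2 == truncdiv(-4, 2) == -2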

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/transform/lower_intrin.cc (2)

111-115: Document correctness for negative dividends.

The power-of-2 optimization floordiv(a, 2^n) → a >> n is correct when right shift is arithmetic (sign-extending), but this is implementation-defined in C++ before C++20. While support_bitwise_op_ likely gates this optimization on platforms with known arithmetic right shift behavior, the code lacks a comment explaining why this transformation is valid for negative values of a.

For readers unfamiliar with the property that arithmetic right shift of negative values matches floor division for power-of-2 divisors, this could appear to be a bug.

Consider adding a comment:

     if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
-      // lower to right shift if possible.
+      // Lower to right shift. For negative dividends, arithmetic right shift
+      // (sign-extending) matches floordiv semantics for power-of-2 divisors.
+      // support_bitwise_op_ is set only when arithmetic right shift is guaranteed.
       return op->a >> make_const(dtype, shift);
     }

213-218: Document correctness for negative dividends.

Similar to the floordiv optimization, the power-of-2 masking for floormod relies on two's complement representation to produce correct results for negative values. For example, floormod(-7, 4) = 1 and -7 & 3 = 1 in two's complement. While mathematically correct, this is subtle and worth documenting.

Consider adding a comment:

     if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
-      // lower to masking if possible.
+      // Lower to masking. Two's complement ensures a & mask produces correct
+      // floormod result for both positive and negative dividends with power-of-2 divisors.
       int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
       return op->a & make_const(dtype, mask);
     }
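As a quick sanity check of both properties noted above, a standalone Python snippet (Python's >> is an arithmetic right shift, and its integers behave as two's complement for bitwise operations):

# Arithmetic right shift matches floor division for power-of-2 divisors,
# and masking matches floor modulo, including for negative dividends.
for a in (-33, -7, -1, 0, 5, 128):
    for shift in (2, 5):
        b = 1 << shift
        assert (a >> shift) == a // b   # floordiv(a, 2^n) == a >> n
        assert (a & (b - 1)) == a % b   # floormod(a, 2^n) == a & (2^n - 1)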
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e21d12d and 39ff2ca.

📒 Files selected for processing (1)
  • src/transform/lower_intrin.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/lower_intrin.cc (2)
tilelang/language/tir/op.py (5)
  • ret (1758-1771)
  • truncdiv (3020-3043)
  • truncmod (3046-3069)
  • floordiv (3072-3091)
  • floormod (3094-3113)
tilelang/transform/__init__.py (1)
  • LowerIntrin (441-444)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: build-test-amd
🔇 Additional comments (1)
src/transform/lower_intrin.cc (1)

126-167: Confirm overflow safety for offset calculation.

No existing overflow checks were found for the expression op->a + op->b * ceildiv. Ensure either:

  • analyzer_ bounds guarantee no overflow for both the multiplication and addition at runtime, or
  • explicit overflow checks/casts to a wider type are added.

@LeiWang1999 changed the title from "[Language] Do not simplify T.ceildiv(a, 32) into (a + 32 -1) >> 5" to "[Transform] Migrate LowerIntrin from tvm into tilelang" on Oct 14, 2025
@LeiWang1999 merged commit 7a5077e into tile-ai:main on Oct 14, 2025
8 checks passed