
Conversation


@chengyupku commented Sep 25, 2025

This pull request introduces a new distributed all-gather GEMM example and adds key utilities for peer-to-peer signaling and tensor allocation in distributed settings. The changes improve support for distributed GPU computations, particularly for multi-process setups, and enhance the flexibility of tensor allocation and communication primitives.

New distributed example and utilities

  • Added a new example script example_allgather_gemm_ipc.py that demonstrates distributed all-gather followed by GEMM computation using custom signaling and buffer sharing, including a correctness check against PyTorch's implementation.
  • Introduced set_signal and wait_eq in tilelang/distributed/utils.py for efficient peer-to-peer signaling between CUDA streams, implemented with CUDA stream memory operations (cuStreamWriteValue32 / cuStreamWaitValue32); a usage sketch follows this list.
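
A minimal usage sketch of the new signaling helpers on a single rank; the set_signal/wait_eq signatures come from this PR's diff, while the flag tensor and the two streams around them are illustrative only:

import torch
from tilelang.distributed.utils import set_signal, wait_eq

# Illustrative setup: a one-element int32 flag in device memory acts as the signal.
flag = torch.zeros(1, dtype=torch.int32, device="cuda")
copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

with torch.cuda.stream(copy_stream):
    # ... enqueue a data transfer on copy_stream ...
    set_signal(flag, 1, stream=copy_stream)  # producer: write 1 once the copy is ordered

with torch.cuda.stream(compute_stream):
    wait_eq(flag, 1, stream=compute_stream)  # consumer: stall this stream until flag == 1
    # ... enqueue work that depends on the transferred data ...

In the cross-rank case the producer writes to the destination rank's signal buffer instead, as the example script does with signal_buffer[local_dst_rank][rank].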

Tensor allocation and API improvements

  • Updated the allocator logic in tilelang/utils/allocator.py and the tensor API in tilelang/utils/tensor.py to support allocating peer tensors across distributed ranks via a new return_peers argument, which returns a list of tensors, one per peer, for distributed memory access (see the sketch after this list).
  • Fixed pointer assignment in the allocator to use .value for correct memory handling.
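
A rough sketch of the new return_peers flag; the allocator call mirrors the example script in this PR, while local_rank, num_local_ranks, and group are assumed to come from init_dist and the shapes are placeholders:

import torch
import tilelang

# Distributed allocator wiring as in the example script (local_rank/num_local_ranks/group from init_dist).
allocator = tilelang.get_allocator(
    size=2**30,
    device="cuda",
    is_distributed=True,
    local_rank=local_rank,
    num_local_ranks=num_local_ranks,
    group=group)

# Default behaviour: a single tensor backed by this rank's pre-allocated buffer.
local_buf = tilelang.tensor((1024, 1024), torch.float16, allocator=allocator)

# With return_peers=True: a list of tensors, one view per peer rank, so the same
# offset can be read or written in every peer's buffer.
ag_buffer = tilelang.tensor((1024, 1024), torch.float16, allocator=allocator, return_peers=True)
ag_buffer[local_rank].fill_(0)  # index by rank to select a particular peer's view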

Miscellaneous

  • Added missing imports and type hints to support new features and maintain code consistency.
  • Minor cleanup in example_all_to_all.py by removing an unused import.

Summary by CodeRabbit

  • New Features

    • Added a distributed all‑gather GEMM example with validation and a CUDA‑IPC style workflow.
    • New signaling/wait utilities and a helper to query max CUDA stream priority.
    • Allocator/tensor APIs can optionally return per‑peer tensors for multi‑GPU workflows.
  • Documentation

    • Consolidated package‑level imports across docs and examples; simplified reference implementations.
  • Tests

    • CI runs distributed and non‑distributed tests in separate passes.
  • Chores

    • Added ninja and cuda‑python to test/dev requirements; adjusted native‑extension build/setup handling and build scripts.


coderabbitai bot commented Sep 25, 2025

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (2)
  • tilelang/language/distributed/__init__.py
  • tilelang/language/distributed/multi_device/__init__.py

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Adds an all-gather GEMM IPC example, re-exports IPC helpers and distributed utilities at package root, adds signaling/stream helpers, extends allocator/tensor APIs to optionally return per-peer tensors, moves PyBind extension wiring into setup.py (removing CMake custom targets), updates many import paths and small example cleanups, and splits CI into distributed vs non-distributed test passes.

Changes

Cohort / File(s) Summary
New example: All-Gather GEMM with IPC
examples/distributed/example_allgather_gemm_ipc.py
New example implementing tiled/pipelined GEMM and intra-node IPC-like all-gather producer and orchestration (barriers/streams), Torch reference, CLI and multiprocessing spawn; exports kernel and orchestration helpers.
Distributed package re-exports & utils
tilelang/distributed/__init__.py, tilelang/distributed/utils.py, tilelang/distributed/common/README.md
Re-export IPC helpers at package root; add set_signal, wait_eq, cuda_stream_max_priority; update docs to import distributed symbols from package root and adjust internal ipc_ext import paths.
Allocator & Tensor API changes
tilelang/utils/allocator.py, tilelang/utils/tensor.py
_allocate_tensor(..., return_peers=False) added and may return per-peer tensor list; tensor(..., return_peers=None) forwards flag and can return list[Tensor]; minor pointer handling adjustment after cudaMalloc.
Examples / Benchmarks / Tests — import updates & small cleanups
examples/distributed/*, examples/distributed/primitives/*, benchmark/distributed/*, benchmark/distributed/ipc_impls/*, tilelang/distributed/testing/sync/test_barrierall_sys.py
Replace imports from tilelang.distributed.utils with imports from tilelang.distributed; remove unused imports/calls (e.g., os, tilelang.disable_cache()), simplify loops by using _ where the index is unused, drop some local-rank retrievals, and remove storing of verification booleans.
Build & packaging: PyBind extension flow moved to setup.py
setup.py, CMakeLists.txt, requirements-test.txt, requirements-dev.txt
Move editable file-based PyBind extension entries for alloc_cuda and ipc_ext into setup.py; remove corresponding CMake custom targets from CMakeLists.txt; add ninja==1.10.0 and cuda-python>=12.0.0 to test/dev requirements.
CI: test selection split
.github/workflows/ci.yml
CI now runs two pytest passes per job: distributed tests with TILELANG_USE_DISTRIBUTED=1 and non-distributed tests, each with parallel pytest -n 4 and informative messages when no tests are found.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Script as example_allgather_gemm_ipc.py
  participant Dist as tilelang.distributed
  participant Alloc as Allocator
  participant IPC as IPC Engine (cp_engine)
  participant GEMM as gemm_kernel
  participant Torch as torch.distributed

  User->>Script: main(args)
  Script->>Dist: init_distributed(...)
  Script->>Alloc: allocate A,B,C, ag_buffer (return_peers?)
  Note over Alloc: ag_buffer may return local + per-peer tensors
  Script->>Dist: set_signal / wait_eq (barriers)
  Script->>IPC: cp_engine_producer_all_gather_put(local, ag_buffer, signal...)
  IPC-->>Script: peer signals
  Script->>GEMM: launch tiled GEMM (A_gathered, B)
  GEMM-->>Script: writes C
  par Reference
    Script->>Torch: all_gather_into_tensor(...)
    Torch-->>Script: A_ref
    Script->>Torch: matmul(A_ref, B)
  end
  Script-->>User: validate results
sequenceDiagram
  autonumber
  participant Caller
  participant TensorAPI as tilelang.utils.tensor.tensor
  participant Alloc as BaseAllocator._allocate_tensor

  Caller->>TensorAPI: tensor(shape, dtype, allocator, return_peers=True)
  TensorAPI->>Alloc: _allocate_tensor(shape, dtype, return_peers=True)
  Alloc-->>TensorAPI: list[Tensor] (local + per-peer views)
  TensorAPI-->>Caller: list[Tensor]

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

I hop through ranks with tiny feet,
Buffers hum and signals meet;
Streams align, peers share their part,
Kernels dance and multiply the heart.
We gather, compute, and cheer—GEMM complete. 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the title captures both primary contributions, return_peers support in _allocate_tensor and the new ag_gemm_ipc example, without extraneous detail.
  • Docstring Coverage (✅ Passed): no functions were found in the changes, so the docstring coverage check was skipped.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/utils/allocator.py (1)

172-201: Prevent invalid cudaFree on suballocations when take_ownership=True

If take_ownership=True on a suballocation, the tensor deleter will call cudaFree on a non-base pointer, which is undefined and can crash. Either disallow take_ownership for suballocations or require the allocation to span the entire backing buffer.

Apply this diff to guard against unsafe ownership:

         numel = _prod_shape(shape)
         itemsize = _element_size_bytes(dtype)
         bytes_needed = numel * itemsize

         bytes_alloc = _align_up(bytes_needed, self._align)

         current_offset = int(self._ptr.value) - int(self._base_ptr.value)
+        # Disallow taking ownership for suballocations: cudaFree must only free the base pointer.
+        if take_ownership and not (current_offset == 0 and bytes_alloc == self.size):
+            raise NotImplementedError(
+                "take_ownership=True is only supported when the allocation spans the entire pre-allocated buffer."
+            )
🧹 Nitpick comments (5)
tilelang/utils/allocator.py (1)

166-171: Return type should reflect multi-tensor return when return_peers=True

The method can return either a single tensor or a list; update the type hint to avoid confusion and downstream type issues.

Apply this diff:

-    def _allocate_tensor(self,
-                         shape: Tuple[int, ...],
-                         dtype: torch.dtype,
-                         return_peers=False,
-                         take_ownership: bool = False) -> torch.Tensor:
+    def _allocate_tensor(self,
+                         shape: Tuple[int, ...],
+                         dtype: torch.dtype,
+                         return_peers: bool = False,
+                         take_ownership: bool = False) -> Union[torch.Tensor, list[torch.Tensor]]:
tilelang/distributed/utils.py (2)

268-282: Support int64 signals and raise specific exceptions

Broaden dtype support and avoid generic Exception. This also addresses static analysis warnings.

Apply this diff:

-def set_signal(signal_tensor: torch.Tensor,
-               signal: int,
-               stream: Optional[torch.cuda.Stream] = None):
+def set_signal(signal_tensor: torch.Tensor,
+               signal: int,
+               stream: Optional[torch.cuda.Stream] = None):
     stream = stream or torch.cuda.current_stream()
     if signal_tensor.dtype == torch.int32:
         (err,) = cuda.cuStreamWriteValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
         )
         CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+    elif signal_tensor.dtype == torch.int64:
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {signal_tensor.dtype}")

284-299: Honor require_i64 and add int64 wait path; avoid generic Exception

Implement the 64-bit wait when requested or when the tensor is int64. Use a specific exception type.

Apply this diff:

-def wait_eq(signal_tensor: torch.Tensor,
-            signal: int,
-            stream: Optional[torch.cuda.Stream] = None,
-            require_i64=False):
+def wait_eq(signal_tensor: torch.Tensor,
+            signal: int,
+            stream: Optional[torch.cuda.Stream] = None,
+            require_i64=False):
     stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
+    if signal_tensor.dtype == torch.int32 and not require_i64:
         (err,) = cuda.cuStreamWaitValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
         )
         CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+    elif signal_tensor.dtype == torch.int64 or require_i64:
+        (err,) = cuda.cuStreamWaitValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {signal_tensor.dtype}")
examples/distributed/example_allgather_gemm_ipc.py (2)

94-96: Clarify ag op signature: rename N to K to match semantics

The arg forwarded to the copy path is K. Rename for correctness.

Apply this diff:

-def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, group,
                local_world_size, world_size, gemm_kernel, ag_stream):

Also update the call site below.


99-103: Forward the renamed K to the producer

Keep parameter usage consistent with the rename.

Apply this diff:

-    with torch.cuda.stream(ag_stream):
-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
-                                          rank, local_world_size, world_size, ag_stream)
+    with torch.cuda.stream(ag_stream):
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
+                                          rank, local_world_size, world_size, ag_stream)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2fbbd76 and 9de9e1b.

📒 Files selected for processing (5)
  • examples/distributed/example_all_to_all.py (0 hunks)
  • examples/distributed/example_allgather_gemm_ipc.py (1 hunks)
  • tilelang/distributed/utils.py (2 hunks)
  • tilelang/utils/allocator.py (3 hunks)
  • tilelang/utils/tensor.py (1 hunks)
💤 Files with no reviewable changes (1)
  • examples/distributed/example_all_to_all.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/utils/tensor.py (2)
tilelang/utils/allocator.py (4)
  • BaseAllocator (71-234)
  • initialized (163-164)
  • device (114-115)
  • _allocate_tensor (166-218)
tilelang/utils/target.py (1)
  • parse_device (87-100)
tilelang/utils/allocator.py (1)
tilelang/utils/cpp/tensor_from_ptr.cpp (2)
  • tensor_from_ptr (46-94)
  • tensor_from_ptr (46-48)
examples/distributed/example_allgather_gemm_ipc.py (5)
tilelang/distributed/utils.py (3)
  • init_dist (34-56)
  • set_signal (268-281)
  • wait_eq (284-298)
tilelang/env.py (1)
  • disable_cache (247-248)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
tilelang/utils/tensor.py (1)
  • tensor (45-58)
🪛 Ruff (0.13.1)
tilelang/distributed/utils.py

281-281: Create your own exception

(TRY002)


281-281: Avoid specifying long messages outside the exception class

(TRY003)


287-287: Unused function argument: require_i64

(ARG001)


298-298: Create your own exception

(TRY002)


298-298: Avoid specifying long messages outside the exception class

(TRY003)

examples/distributed/example_allgather_gemm_ipc.py

83-83: Unpacked variable err is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


184-184: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (2)
tilelang/utils/allocator.py (1)

128-128: Correct fix: assign pointer via .value

Copying the device pointer value into _ptr is correct and fixes the previous aliasing bug.

tilelang/utils/tensor.py (1)

56-57: Forwarding return_peers — LGTM

Propagating return_peers to the allocator is correct and aligns the public API with allocator behavior.

Comment on lines 8 to 11
from tilelang.distributed.utils import init_dist
from cuda import cudart
from tilelang.distributed.utils import set_signal, wait_eq



⚠️ Potential issue | 🟠 Major

Make cudart import compatible with cuda-python >= 12.8 and import CUDA_CHECK

Directly importing from cuda import cudart breaks on newer cuda-python. Use the same version-gated import as utils and import CUDA_CHECK for error handling.

Apply this diff:

-from tilelang.distributed.utils import init_dist
-from cuda import cudart
-from tilelang.distributed.utils import set_signal, wait_eq
+from tilelang.distributed.utils import init_dist
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart
+from tilelang.distributed.utils import set_signal, wait_eq, CUDA_CHECK
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 8 to 11, the
direct "from cuda import cudart" import breaks on cuda-python >= 12.8; update
the imports to match the version-gated pattern used in
tilelang.distributed.utils (perform a conditional import that tries the new
module path and falls back to the old one) and also import CUDA_CHECK for error
handling, then replace usages of cudart error checks to use CUDA_CHECK where
appropriate.

Comment on lines +33 to +44
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    A_shared = T.alloc_shared((block_M, block_K), dtype)
    B_shared = T.alloc_shared((block_K, block_N), dtype)
    C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

    T.clear(C_local)
    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
        T.copy(A[by * block_M, k * block_K], A_shared)
        T.copy(B[k * block_K, bx * block_N], B_shared)
        T.gemm(A_shared, B_shared, C_local)
    T.copy(C_local, C[by * block_M, bx * block_N])



⚠️ Potential issue | 🔴 Critical

Fix grid size: using N instead of N_per_rank causes out-of-bounds on B/C

Kernel grid must reflect local N dimension per rank. Using global N overshoots.

Apply this diff:

-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 33 to 44 the
kernel grid uses the global N which causes out-of-bounds accesses for B/C on
each rank; change the x-dimension grid calculation to use the per-rank local
width (replace T.ceildiv(N, block_N) with T.ceildiv(N_per_rank, block_N)) and
ensure any indexing/limits that assume the x-dimension (reads from B and writes
to C) are bounded by N_per_rank so the kernel only addresses the local slice.

Comment on lines +48 to +51
def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
                                      signal_target, rank, local_world_size, world_size,
                                      intranode_ag_stream):
    local_rank = rank % local_world_size


⚠️ Potential issue | 🟠 Major

Parameter naming bug: N here is actually K (second dim of A)

Misnaming invites misuse. Rename to K and propagate to all uses.

Apply this diff:

-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):

Follow-up diffs below update its uses.

🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 48 to 51, the
function parameter named N is incorrect — it represents K (the second dimension
of A); rename the parameter to K in the function signature and update every call
site and internal reference to use K instead of N, ensuring variable names and
any related comments/docstrings reflect the new name so behavior remains
unchanged.

Comment on lines +55 to +69
for i in range(1, local_world_size):
    segment = rank * M_per_rank * N
    local_dst_rank = (local_rank + local_world_size - i) % local_world_size
    src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
    dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
    # Using copy engine to perform intranode transmission
    # Sending rank-th local tensor to other ranks inside the node.
    (err,) = cudart.cudaMemcpyAsync(
        dst_ptr,
        src_ptr,
        M_per_rank * N * local_tensor.element_size(),
        cudart.cudaMemcpyKind.cudaMemcpyDefault,
        intranode_ag_stream.cuda_stream,
    )
    # Notify the peer that the transmission is done.


🛠️ Refactor suggestion | 🟠 Major

Use K for segment sizing and add CUDA error checks

Copy size/offset must be based on the width of A (K). Also check cudaMemcpyAsync return codes.

Apply this diff:

-    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
-        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
-        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
+        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
+        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
         (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)

Comment on lines +72 to +92
for i in range(1, n_nodes):
    recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
    recv_segment = recv_rank * M_per_rank * N
    # Waiting for the internode data ready
    wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
    src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
    for j in range(1, local_world_size):
        local_dst_rank = (local_rank + local_world_size - j) % local_world_size
        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + recv_segment * local_tensor.element_size()
        # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
        (err,) = cudart.cudaMemcpyAsync(
            dst_ptr,
            src_ptr,
            M_per_rank * N * local_tensor.element_size(),
            cudart.cudaMemcpyKind.cudaMemcpyDefault,
            intranode_ag_stream.cuda_stream,
        )
        # Notify the peer that the transmission is done.
        set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)



🛠️ Refactor suggestion | 🟠 Major

Likewise for inter-node segment sizing and error checks

Mirror the K-based sizing and add error handling.

Apply this diff:

-    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
             (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
🧰 Tools
🪛 Ruff (0.13.1)

83-83: Unpacked variable err is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 72 to 92, the
inter-node segment sizing currently uses M_per_rank but should mirror the
K-based sizing (use K_per_rank) and it lacks error checks on cudaMemcpyAsync;
change recv_segment and the byte count passed to cudaMemcpyAsync to use
K_per_rank * N (times element_size), check the returned err from
cudart.cudaMemcpyAsync and raise or log a RuntimeError if err is non-zero, and
only call set_signal after the copy completed successfully (or after handling
the error).

Comment on lines +173 to +175
tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
                        group, num_local_ranks, num_local_ranks, kernel, ag_stream)



⚠️ Potential issue | 🟠 Major

Pass the correct world size (global) to ag_gemm_op

You currently pass num_local_ranks for both local and global sizes. Use the true num_ranks for world_size.

Apply this diff:

-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 173 to 175,
the call to ag_gemm_op passes num_local_ranks twice (for both local and
global/world sizes); change the second num_local_ranks argument to the true
global world size variable num_ranks so the function receives local_size,
world_size correctly—update the ag_gemm_op call to pass num_ranks as the
world_size parameter.

Comment on lines +202 to +211
if return_peers:
    peer_ts = []
    for i in range(self._group.size()):
        if i == self._local_rank:
            peer_ts.append(t)
        else:
            peer_ptr_val = int(self._buffer_ptrs[i]) + current_offset
            peer_t = tensor_from_ptr(peer_ptr_val, shape, dtype_str, self._device, False)
            peer_ts.append(peer_t)



⚠️ Potential issue | 🟠 Major

Guard return_peers usage for non-distributed allocators

When return_peers=True, code assumes _group and _buffer_ptrs are initialized. Without distribution, this raises at runtime.

Apply this diff:

-        if return_peers:
+        if return_peers:
+            if not self._is_distributed or self._group is None or self._buffer_ptrs is None:
+                raise RuntimeError(
+                    "return_peers=True requires a distributed allocator with initialized peer mappings."
+                )
             peer_ts = []
             for i in range(self._group.size()):
                 if i == self._local_rank:
                     peer_ts.append(t)
                 else:
                     peer_ptr_val = int(self._buffer_ptrs[i]) + current_offset
                     peer_t = tensor_from_ptr(peer_ptr_val, shape, dtype_str, self._device, False)
                     peer_ts.append(peer_t)
🤖 Prompt for AI Agents
In tilelang/utils/allocator.py around lines 202 to 211, the block that builds
peer_ts when return_peers=True assumes self._group and self._buffer_ptrs exist
and is distributed; guard this by checking whether the allocator is running in
distributed mode (e.g., if self._group is None or not getattr(self,
"_buffer_ptrs", None):) before using them, and if the check fails raise a clear
ValueError or RuntimeError stating that return_peers requires a distributed
allocator (or instruct the caller to set return_peers=False), otherwise proceed
to build peer_ts as before.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
examples/distributed/example_allgather_gemm_ipc.py (1)

94-106: Align ag_gemm_op signature with K and compute local_rank

Prevents confusion and enables correct local peer indexing later.

Apply this diff:

-def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, group,
                local_world_size, world_size, gemm_kernel, ag_stream):
 
     dist.barrier(group)
 
     # all_gather A to ag_buffer
     with torch.cuda.stream(ag_stream):
-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)
 
     current_stream = torch.cuda.current_stream()
     current_stream.wait_stream(ag_stream)
+    local_rank = rank % local_world_size
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9de9e1b and e9d360a.

📒 Files selected for processing (1)
  • examples/distributed/example_allgather_gemm_ipc.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/distributed/example_allgather_gemm_ipc.py (6)
tilelang/distributed/utils.py (3)
  • init_dist (34-56)
  • set_signal (268-281)
  • wait_eq (284-298)
tilelang/env.py (1)
  • disable_cache (247-248)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
tilelang/utils/tensor.py (1)
  • tensor (45-58)
🪛 Ruff (0.13.1)
examples/distributed/example_allgather_gemm_ipc.py

83-83: Unpacked variable err is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


183-183: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (8)
examples/distributed/example_allgather_gemm_ipc.py (8)

48-51: Rename parameter: N here is K (width of A)

Avoid confusion and downstream misuse by renaming.

Apply this diff:

-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):

Follow-up: update all internal references and call sites to use K (see separate comments).


119-127: Torch reference path looks good

Using all_gather_into_tensor + matmul is correct for the reference result.


8-11: Gate cudart import for cuda-python >= 12.8 and import CUDA_CHECK

Direct from cuda import cudart breaks on newer cuda-python. Mirror the version-gated import pattern and bring in CUDA_CHECK for error handling.

Apply this diff:

-from tilelang.distributed.utils import init_dist
-from cuda import cudart
-from tilelang.distributed.utils import set_signal, wait_eq
+from tilelang.distributed.utils import init_dist
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart
+from tilelang.distributed.utils import set_signal, wait_eq, CUDA_CHECK

33-44: Fix grid x-dimension to use local N per rank

Kernel grid uses global N, causing out-of-bounds for B/C which are sized to N // num_rank.

Apply this diff:

-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

55-71: Use K for segment sizing and check cudaMemcpyAsync errors

Byte counts/offsets must use K. Also, check return codes.

Apply this diff:

-    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
         src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
         dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
-        (err,) = cudart.cudaMemcpyAsync(
+        (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)

72-92: Mirror K-based sizing for inter-node path and add error checks

Also resolves “err is never used” warning.

Apply this diff:

-    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
-            (err,) = cudart.cudaMemcpyAsync(
+            (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)

172-174: Pass global world_size to ag_gemm_op

You currently pass num_local_ranks for both local and global sizes.

Apply this diff:

-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)

12-12: Module‐level disable_cache alias is defined
disable_cache = CacheState.disable is declared in tilelang/env.py and re-exported in tilelang/__init__.py, so tilelang.disable_cache() is valid. No changes needed.

Likely an incorrect or invalid review comment.

Comment on lines +110 to +114
torch.cuda.synchronize()
torch.distributed.barrier(group)
gemm_kernel(ag_buffer[rank], B, C)
torch.cuda.synchronize()
torch.distributed.barrier(group)


⚠️ Potential issue | 🔴 Critical

Use local peer for kernel input; drop redundant sync

Fix out-of-range access when rank >= local_world_size. Also remove duplicate synchronize.

Apply this diff:

-    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)
     torch.cuda.synchronize()
-    torch.distributed.barrier(group)
+    torch.distributed.barrier(group)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 110 to 114,
the code uses the global rank as the index into ag_buffer which can be
out-of-range when rank >= local_world_size; change the kernel input index to a
local peer index (e.g., local_peer = rank % local_world_size or however
local_world_size is computed in this module) and call
gemm_kernel(ag_buffer[local_peer], B, C) instead of ag_buffer[rank]; also remove
the redundant torch.cuda.synchronize() (keep a single synchronize after the
kernel) so you only call torch.distributed.barrier(group), gemm_kernel(... with
local_peer ...), torch.cuda.synchronize(), and torch.distributed.barrier(group).

Comment on lines +130 to +151
def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    dtype = torch.float16
    M = args.M if args else 8192
    N = args.N if args else 8192
    K = args.K if args else 8192
    M_per_rank = M // num_local_ranks
    N_per_rank = N // num_local_ranks

    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    threads = 256

    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    allocator = tilelang.get_allocator(
        size=2**30,
        device="cuda",
        is_distributed=True,
        local_rank=local_rank,
        num_local_ranks=num_local_ranks,
        group=group)
    kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads))


⚠️ Potential issue | 🔴 Critical

Compute per-rank sizes using the global world size and after init_dist

Current M_per_rank/N_per_rank use num_local_ranks, which breaks multi-node (OOB segment indexing).

Apply this diff:

-    M = args.M if args else 8192
-    N = args.N if args else 8192
-    K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks
+    M = args.M if args else 8192
+    N = args.N if args else 8192
+    K = args.K if args else 8192
@@
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 130-151,
M_per_rank and N_per_rank are computed too early using num_local_ranks which
breaks multi-node indexing; move the M_per_rank/N_per_rank calculations to after
the init_dist(...) call and compute them with the global world size (use
num_ranks returned by init_dist), e.g. M_per_rank = M // num_ranks and
N_per_rank = N // num_ranks, ensuring the per-rank sizes are based on the global
number of ranks and evaluated after init_dist.

Comment on lines +156 to +166
A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
signal_buffer = tilelang.tensor((num_local_ranks,),
                                torch.int32,
                                allocator=allocator,
                                return_peers=True)
signal_buffer[rank].fill_(0)
ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)



⚠️ Potential issue | 🔴 Critical

Fix peer buffer/signal shapes and local indexing

  • signal_buffer inner length must be world_size (you index by global rank).
  • Index ag_buffer/signal_buffer by local_rank, not global rank.

Apply this diff:

-    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
-    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
-    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
-    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
-    signal_buffer = tilelang.tensor((num_local_ranks,),
-                                    torch.int32,
-                                    allocator=allocator,
-                                    return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
+    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
+    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
+    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
+    signal_buffer = tilelang.tensor((num_ranks,),
+                                    torch.int32,
+                                    allocator=allocator,
+                                    return_peers=True)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 156 to 166,
the peer signal buffer is created with length num_local_ranks and both ag_buffer
and signal_buffer are indexed by global rank; change the signal_buffer inner
length to world_size (since you index by global ranks across all peers) and
change all indexes that use rank to use local_rank instead (e.g., access
ag_buffer and signal_buffer via local_rank when writing local data/signal and
keep slices based on local offsets using local_rank * M_per_rank ..
(local_rank+1) * M_per_rank). Ensure the buffer constructions use allocator and
return_peers as before but with the corrected shape for signal_buffer and
replace global rank indexing with local_rank for local writes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

🧹 Nitpick comments (5)
.github/workflows/ci.yml (2)

126-133: Optional: simplify discovery by using pytest selection instead of find/mapfile

You can avoid brittle shell/globbing by letting pytest handle selection (keeps reporting consistent, works across platforms).

Example (no need to change now):

  • Distributed examples:
    • TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 -v -r fE -k "distributed and test"
  • Non-distributed examples:
    • python -m pytest -n 4 -v -r fE -k "not distributed and test"

152-159: Optional: let pytest handle non-distributed selection

Using -k avoids maintaining separate find expressions and scales better as tests move.

Example (no immediate change required):

  • python -m pytest -n 4 -v -r fE -k "not distributed and test"
tilelang/distributed/utils.py (1)

268-282: Signal helpers: support 64-bit signals and raise precise exceptions.

  • set_signal only supports int32; many call sites use uint64 signals.
  • Use cuStreamWriteValue64 when tensor is int64/uint64.
  • Raise TypeError instead of Exception.

Apply:

-def set_signal(signal_tensor: torch.Tensor,
-               signal: int,
-               stream: Optional[torch.cuda.Stream] = None):
-    stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
-        (err,) = cuda.cuStreamWriteValue32(
-            stream.cuda_stream,
-            signal_tensor.data_ptr(),
-            signal,
-            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
-        )
-        CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+def set_signal(
+    signal_tensor: torch.Tensor,
+    signal: int,
+    stream: Optional[torch.cuda.Stream] = None,
+):
+    stream = stream or torch.cuda.current_stream()
+    dt = signal_tensor.dtype
+    if dt == torch.int32:
+        (err,) = cuda.cuStreamWriteValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")

Additionally, define a small custom exception if preferred by your style guide:

class SignalDTypeError(TypeError):
    pass

And replace TypeError with SignalDTypeError. As per coding guidelines

examples/distributed/example_allgather_gemm_ipc.py (2)

110-114: Remove redundant synchronization calls.

Multiple consecutive synchronization calls are redundant and impact performance.

Apply this diff to clean up synchronization:

-    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
     gemm_kernel(ag_buffer[rank], B, C)
     torch.cuda.synchronize()
     torch.distributed.barrier(group)

183-183: Use exception chaining for better debugging context.

When the check fails, raise the new exception from the underlying error so the original context is preserved, and include the failing rank in the message.

Apply this diff:

-        raise ValueError("Test failed")
+        raise ValueError(f"Correctness check failed for rank {local_rank}")
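For illustration, a chained raise at that call site could look like the sketch below; it assumes the example's existing imports and variables, and torch_C is a placeholder for the PyTorch reference result:

    try:
        torch.testing.assert_close(tilelang_C, torch_C, rtol=1e-2, atol=1e-2)
    except AssertionError as exc:
        # Chain from the original assertion so its mismatch details stay attached.
        raise ValueError(f"Correctness check failed for rank {local_rank}") from exc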
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a053313 and abfbcd4.

📒 Files selected for processing (27)
  • .github/workflows/ci.yml (1 hunks)
  • benchmark/distributed/benchmark_ag_gemm.py (1 hunks)
  • benchmark/distributed/benchmark_all_gather.py (1 hunks)
  • benchmark/distributed/benchmark_all_to_all.py (1 hunks)
  • benchmark/distributed/benchmark_gemm_rs.py (1 hunks)
  • benchmark/distributed/benchmark_reduce_scatter.py (1 hunks)
  • benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1 hunks)
  • benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1 hunks)
  • examples/distributed/example_all_to_all.py (1 hunks)
  • examples/distributed/example_allgather.py (1 hunks)
  • examples/distributed/example_allgather_gemm.py (1 hunks)
  • examples/distributed/example_allgather_gemm_ipc.py (1 hunks)
  • examples/distributed/example_cannon.py (1 hunks)
  • examples/distributed/example_gemm_rs.py (1 hunks)
  • examples/distributed/example_post_attn_all2all_transpose.py (3 hunks)
  • examples/distributed/example_pre_attn_all2all.py (3 hunks)
  • examples/distributed/example_pre_attn_all2all_transpose.py (3 hunks)
  • examples/distributed/example_simple_shift.py (1 hunks)
  • examples/distributed/example_summa.py (1 hunks)
  • examples/distributed/primitives/example_get_block.py (1 hunks)
  • examples/distributed/primitives/example_get_warp.py (1 hunks)
  • examples/distributed/primitives/example_put_block.py (1 hunks)
  • examples/distributed/primitives/example_put_warp.py (1 hunks)
  • tilelang/distributed/common/README.md (1 hunks)
  • tilelang/distributed/testing/sync/test_barrierall_sys.py (1 hunks)
  • tilelang/distributed/utils.py (3 hunks)
  • tilelang/utils/allocator.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/utils/allocator.py
🧰 Additional context used
🧬 Code graph analysis (23)
examples/distributed/primitives/example_get_warp.py (1)
tilelang/distributed/utils.py (1)
  • init_dist (34-56)
examples/distributed/example_post_attn_all2all_transpose.py (2)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/example_pre_attn_all2all.py (1)
  • verify_results (138-161)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1)
tilelang/distributed/utils.py (2)
  • init_dist (34-56)
  • perf_fn (217-238)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
benchmark/distributed/benchmark_ag_gemm.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
benchmark/distributed/benchmark_all_gather.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
examples/distributed/example_all_to_all.py (2)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
benchmark/distributed/benchmark_gemm_rs.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
benchmark/distributed/benchmark_all_to_all.py (1)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
tilelang/distributed/testing/sync/test_barrierall_sys.py (1)
tilelang/distributed/utils.py (1)
  • init_dist (34-56)
examples/distributed/example_gemm_rs.py (2)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • generate_data (168-170)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
examples/distributed/example_cannon.py (1)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/primitives/example_put_warp.py (1)
tilelang/distributed/utils.py (1)
  • init_dist (34-56)
examples/distributed/example_allgather_gemm.py (1)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/primitives/example_get_block.py (1)
tilelang/distributed/utils.py (1)
  • init_dist (34-56)
examples/distributed/example_pre_attn_all2all_transpose.py (2)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/example_post_attn_all2all_transpose.py (1)
  • verify_results (141-164)
examples/distributed/primitives/example_put_block.py (1)
tilelang/distributed/utils.py (1)
  • init_dist (34-56)
examples/distributed/example_allgather_gemm_ipc.py (6)
tilelang/distributed/utils.py (3)
  • init_dist (34-56)
  • set_signal (268-281)
  • wait_eq (284-298)
examples/distributed/example_allgather_gemm.py (1)
  • main (15-50)
tilelang/utils/allocator.py (2)
  • get_allocator (237-249)
  • device (114-115)
tilelang/jit/__init__.py (1)
  • compile (32-81)
tilelang/jit/kernel.py (1)
  • initialize (400-409)
tilelang/utils/tensor.py (1)
  • tensor (45-58)
examples/distributed/example_summa.py (1)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/example_simple_shift.py (2)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
tilelang/profiler/__init__.py (1)
  • init_distributed (67-88)
examples/distributed/example_pre_attn_all2all.py (2)
tilelang/distributed/utils.py (1)
  • init_distributed (59-82)
examples/distributed/example_post_attn_all2all_transpose.py (1)
  • verify_results (141-164)
benchmark/distributed/benchmark_reduce_scatter.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
examples/distributed/example_allgather.py (1)
tilelang/distributed/utils.py (2)
  • init_distributed (59-82)
  • perf_fn (217-238)
🪛 Ruff (0.13.1)
examples/distributed/example_allgather_gemm_ipc.py

83-83: Unpacked variable err is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


183-183: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/distributed/utils.py

281-281: Create your own exception

(TRY002)


281-281: Avoid specifying long messages outside the exception class

(TRY003)


287-287: Unused function argument: require_i64

(ARG001)


298-298: Create your own exception

(TRY002)


298-298: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia
🔇 Additional comments (32)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1)

8-8: Import path aligns with new public API surface.

tilelang.distributed now re-exports init_dist and perf_fn, so this change keeps the benchmark in sync with the refreshed namespace.

examples/distributed/example_post_attn_all2all_transpose.py (3)

6-6: Consolidated distributed imports look good.

Pulling init_distributed and dtype_map from the package root keeps these examples aligned with the new public surface.


45-45: Loop variable cleanup appreciated.

Switching to _ makes it explicit the index isn’t used and quiets static analyzers.


265-265: Direct verification call is fine.

Dropping the unused local keeps the flow lean while preserving the check.

benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1)

12-12: Import consolidation to the public API is fine; verify re-exports.

Ensure tilelang.distributed publicly exposes init_distributed and perf_fn without introducing a utils ↔ __init__ import cycle.

Use the same script provided in the other comment to confirm the __init__ re-exports and no circular imports.

tilelang/distributed/utils.py (1)

301-304: cuda_stream_max_priority: LGTM.

examples/distributed/primitives/example_put_warp.py (1)

8-8: Import path update is fine; ensure init_dist is re-exported from tilelang.distributed.

Confirm tilelang.distributed exposes init_dist to avoid breaking this example.

Reuse the verification script posted earlier to inspect tilelang/distributed/__init__.py for init_dist re-exports and to detect cycles.

examples/distributed/example_cannon.py (1)

6-6: Public API import switch looks good.

No other changes in this file; behavior unaffected.

benchmark/distributed/benchmark_gemm_rs.py (1)

12-12: Import move to the public API namespace is OK; verify re-exports.

Ensure tilelang.distributed re-exports init_distributed, dtype_map, perf_fn to prevent runtime failures.

Use the same __init__.py verification script mentioned previously.

examples/distributed/example_allgather_gemm.py (1)

7-7: Import consolidation to tilelang.distributed is fine.

No behavior change; just verify init_distributed is re-exported in __init__.py.

Same verification script as earlier applies.

examples/distributed/primitives/example_get_warp.py (1)

8-8: Import path change LGTM; confirm re-export of init_dist.

Avoid breakage by ensuring tilelang.distributed exposes init_dist.

Same __init__.py verification script applies.

benchmark/distributed/benchmark_all_to_all.py (1)

4-4: Import-path update looks correct

Switching to the public tilelang.distributed re-export keeps behavior unchanged and aligns with the PR-wide consolidation. 👍

examples/distributed/primitives/example_get_block.py (1)

8-8: Import routed through public module

Using tilelang.distributed for init_dist matches the new API surface and maintains existing functionality.

examples/distributed/primitives/example_put_block.py (1)

8-8: Public import confirmed

The updated import mirrors the new re-export and leaves runtime behavior intact.

examples/distributed/example_all_to_all.py (1)

6-6: Re-export usage is consistent

The module now sources init_distributed from the consolidated public package; no other changes required.

tilelang/distributed/testing/sync/test_barrierall_sys.py (1)

9-9: Test import aligned with public API

Pointing the test to tilelang.distributed keeps it in sync with the new public surface.

benchmark/distributed/benchmark_all_gather.py (1)

7-7: Benchmark import path updated correctly

All three utilities (init_distributed, dtype_map, perf_fn) are now pulled from the public module, matching the project-wide change.

examples/distributed/example_summa.py (1)

6-6: Public re-export in use

The example now depends on the consolidated tilelang.distributed surface; no issues spotted.

examples/distributed/example_allgather.py (1)

7-7: Allgather example aligned with new API

Imports of init_distributed, dtype_map, and perf_fn now match the re-exported public interface.

benchmark/distributed/benchmark_ag_gemm.py (1)

21-21: LGTM! Import path consolidation aligns with project-wide refactor.

The import path change from tilelang.distributed.utils to tilelang.distributed correctly reflects the public API surface reorganization happening across the project.

examples/distributed/example_gemm_rs.py (1)

5-5: LGTM! Import path update matches project conventions.

The change from tilelang.distributed.utils to tilelang.distributed is consistent with the broader public API reorganization.

tilelang/distributed/common/README.md (1)

15-15: LGTM! Documentation correctly reflects API surface changes.

The README appropriately documents the new import path from tilelang.distributed instead of the previous tilelang.distributed.utils.

examples/distributed/example_allgather_gemm_ipc.py (10)

9-10: Fix cuda-python compatibility for cudart import.

Direct import from cuda breaks on cuda-python >= 12.8. Use version-gated import pattern.

Apply this diff to fix the import compatibility:

-from cuda import cudart
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart

33-33: Fix grid dimensions to use per-rank local N dimension.

The kernel grid uses global N which causes out-of-bounds accesses. Must use local N_per_rank = N // num_rank.

Apply this diff:

-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

48-50: Rename parameter N to K to match tensor dimensions.

The parameter N actually represents the K dimension (second dimension of A). This misnaming causes confusion.

Apply this diff:

-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):

Also update all usages of N within this function to K.


56-56: Use K for segment sizing and add error checking.

Segment calculation and copy size must use K dimension. Also add CUDA error checking.

Apply this diff:

-        segment = rank * M_per_rank * N
+        segment = rank * M_per_rank * K
...
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),

Also add error checking after line 68:

+        from tilelang.distributed import CUDA_CHECK
+        CUDA_CHECK(err)

Also applies to: 65-65


74-74: Fix inter-node segment sizing to use K.

Inter-node segments and copy sizes should also use K dimension.

Apply this diff:

-        recv_segment = recv_rank * M_per_rank * N
+        recv_segment = recv_rank * M_per_rank * K
...
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),

Also applies to: 86-86


112-112: Use local_rank to index ag_buffer.

Using global rank to index ag_buffer causes out-of-bounds access when rank >= num_local_ranks.

Apply this diff:

-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)

Also add local_rank computation at the start of the function:

+    local_rank = rank % local_world_size

135-136: Compute per-rank sizes using global world size after init_dist.

Per-rank sizes computed with num_local_ranks break multi-node setups. Move calculation after init_dist.

Apply this diff:

     K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks

     BLOCK_M = 128
     BLOCK_N = 128
     BLOCK_K = 64
     threads = 256

     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks

Also applies to: 143-143


160-165: Fix signal buffer size and use local_rank for indexing.

Signal buffer must be sized for global world_size since you index by global rank. Also use local_rank for local operations.

Apply this diff:

-    signal_buffer = tilelang.tensor((num_local_ranks,),
+    signal_buffer = tilelang.tensor((num_ranks,),
                                     torch.int32,
                                     allocator=allocator,
                                     return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)

172-173: Pass correct world_size to ag_gemm_op.

Currently passing num_local_ranks for both local and world sizes. Should pass num_ranks for world_size.

Apply this diff:

     tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)

83-83: Add CUDA error checking or use underscore prefix.

The error return value is unpacked but not checked, which could hide failures.

Apply this diff to add error checking:

-            (err,) = cudart.cudaMemcpyAsync(
+            (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            from tilelang.distributed import CUDA_CHECK
+            CUDA_CHECK(err)

Or if error checking is intentionally omitted, use underscore prefix:

-            (err,) = cudart.cudaMemcpyAsync(
+            (_err,) = cudart.cudaMemcpyAsync(

Comment on lines +115 to +123
# find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
printf '%s\n' "${DIST_TESTS[@]}"
TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
else
echo "No distributed examples found."
fi

⚠️ Potential issue | 🟠 Major

Run distributed example tests serially and harden shell options

Parallelizing distributed GPU tests with pytest-xdist (-n 4) commonly oversubscribes GPUs/process groups and causes flaky hangs/OOMs. Also add strict shell flags to fail fast on CI.

Apply this diff:

-        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed examples found."
         fi
🤖 Prompt for AI Agents
In .github/workflows/ci.yml around lines 115 to 123, the workflow currently runs
distributed GPU tests in parallel and lacks strict shell failure flags; change
it to run the distributed tests serially (remove pytest-xdist -n 4) so
TILELANG_USE_DISTRIBUTED=1 python -m pytest "${DIST_TESTS[@]}" -v -r fE is used,
and add strict shell options at the start of the shell block (e.g., set -euo
pipefail and IFS=$'\n\t') so the job fails fast on errors and handles unset
variables/pipe failures robustly.

Comment on lines +141 to +149
# run distributed tests first with env var
mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
printf '%s\n' "${DIST_TESTS[@]}"
TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
else
echo "No distributed tests found under testing/python."
fi

⚠️ Potential issue | 🟠 Major

Run distributed tests (testing/python) serially and add strict shell flags

Same flakiness risk here from -n 4 on GPU-distributed tests; also add set -euo pipefail for safety.

Apply this diff:

-        # run distributed tests first with env var
+        # run distributed tests first with env var
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed tests found under testing/python."
         fi
🤖 Prompt for AI Agents
.github/workflows/ci.yml lines 141-149: the workflow runs GPU-distributed tests
in parallel with pytest -n 4 which causes flakiness and lacks strict shell
safety; remove the -n 4 flag so pytest runs serially for those distributed
tests, and add strict shell flags (set -euo pipefail) before this block (or at
top of the script/job) so failures, undefined variables, and pipe errors abort
the job; ensure TILELANG_USE_DISTRIBUTED=1 is preserved when invoking pytest.


# all_gather A to ag_buffer
with torch.cuda.stream(ag_stream):
cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,

⚠️ Potential issue | 🟠 Major

Pass correct K dimension to cp_engine_producer_all_gather_put.

The function call passes N but should pass K to match the corrected function signature.

Apply this diff:

-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around line 101, the call
to cp_engine_producer_all_gather_put incorrectly passes N as the K dimension;
update the call to pass K instead of N so the arguments match the corrected
function signature (replace the N argument with K in that function call).

print(f"PE {RANK} TileLang output sample: {tilelang_output[0, 0, 0, :3]}")

verification_passed = verify_results(tilelang_output, torch_output, RANK)
verify_results(tilelang_output, torch_output, RANK)

⚠️ Potential issue | 🟠 Major

Restore verification failure propagation

By discarding the boolean return from verify_results, any mismatch now only prints to stdout and the script exits successfully, hiding real correctness regressions. Please reintroduce the check so the example fails fast when verification fails.

-    verify_results(tilelang_output, torch_output, RANK)
+    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    if not verification_passed:
+        raise RuntimeError(f"PE {RANK}: TileLang output does not match PyTorch reference")
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all_transpose.py around line 260,
the call to verify_results(tilelang_output, torch_output, RANK) currently
ignores its boolean return value; restore failure propagation by capturing the
return and exiting non-zero on failure (e.g., if not verify_results(...) then
print/log error and call sys.exit(1) or raise an exception) so the script fails
fast when verification fails.

print(f"PE {RANK} TileLang output sample: {tilelang_output[0, 0, 0, :3]}")

verification_passed = verify_results(tilelang_output, torch_output, RANK)
verify_results(tilelang_output, torch_output, RANK)

⚠️ Potential issue | 🟠 Major

Restore failure propagation from verify_results.

We now drop the boolean return, so the script exits successfully even when verification fails—automations won’t catch regressions. Capture the return and exit (or raise) on failure.

-    verify_results(tilelang_output, torch_output, RANK)
+    success = verify_results(tilelang_output, torch_output, RANK)
+    if not success:
+        raise SystemExit(1)
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around line 263, the call to
verify_results(tilelang_output, torch_output, RANK) currently ignores its
boolean result so failures are treated as success; capture the return value and
fail the script on false by either raising an exception or calling sys.exit with
a non-zero status (e.g., import sys and if not ok: sys.exit(1)) so CI and
automation detect verification regressions.

import tilelang.language as T
from tilelang.profiler import TensorSupplyType
from tilelang.distributed.utils import init_distributed
from tilelang.distributed import init_distributed

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Confirm init_distributed re-export

Switching the import to from tilelang.distributed import init_distributed assumes the package's __init__ now re-exports that symbol; otherwise this will raise an ImportError at import time. Please confirm the re-export exists.


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Verify that tilelang/distributed/__init__.py exposes init_distributed.
if [[ -f tilelang/distributed/__init__.py ]]; then
  rg -n "init_distributed" tilelang/distributed/__init__.py
else
  echo "tilelang/distributed/__init__.py missing"
fi

Length of output: 40


Fix import for init_distributed
The tilelang/distributed directory has no __init__.py, so

from tilelang.distributed import init_distributed

fails with an ImportError because nothing re-exports the function. Either:

  • Add tilelang/distributed/__init__.py with
    from .utils import init_distributed
  • Or update examples/distributed/example_simple_shift.py:4 to
    from tilelang.distributed.utils import init_distributed
🤖 Prompt for AI Agents
In examples/distributed/example_simple_shift.py around line 4, the import fails
because tilelang/distributed has no __init__.py that re-exports the symbol;
either update the file to import init_distributed from
tilelang.distributed.utils, or add tilelang/distributed/__init__.py that imports
and exposes init_distributed so the original import works.


import ctypes
from tilelang.distributed.common.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
from tilelang.distributed import _create_ipc_handle, _sync_ipc_handles, _create_tensor

⚠️ Potential issue | 🔴 Critical

Avoid circular imports: import IPC helpers from the submodule, not the package.

Importing from tilelang.distributed here can create a cycle if __init__.py re-exports from utils. Import the extension directly via a relative path.

Apply:

-from tilelang.distributed import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+from .common.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
🤖 Prompt for AI Agents
In tilelang/distributed/utils.py around line 20, the file currently imports IPC
helpers from the package (from tilelang.distributed import _create_ipc_handle,
_sync_ipc_handles, _create_tensor) which can produce a circular import if
__init__.py re-exports utils; change this to a direct relative import from the
extension module (e.g. from ._ipc import _create_ipc_handle, _sync_ipc_handles,
_create_tensor) so utils imports the IPC helpers from the implementation module
rather than the package entry point.

Comment on lines +284 to +299
def wait_eq(signal_tensor: torch.Tensor,
signal: int,
stream: Optional[torch.cuda.Stream] = None,
require_i64=False):
stream = stream or torch.cuda.current_stream()
if signal_tensor.dtype == torch.int32:
(err,) = cuda.cuStreamWaitValue32(
stream.cuda_stream,
signal_tensor.data_ptr(),
signal,
cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
)
CUDA_CHECK(err)
else:
raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")


⚠️ Potential issue | 🟠 Major

wait_eq: implement 64-bit path and remove/repurpose the unused require_i64 arg.

  • The require_i64 argument is unused.
  • Support int64/uint64 via cuStreamWaitValue64; error on other dtypes.

Apply:

-def wait_eq(signal_tensor: torch.Tensor,
-            signal: int,
-            stream: Optional[torch.cuda.Stream] = None,
-            require_i64=False):
-    stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
-        (err,) = cuda.cuStreamWaitValue32(
-            stream.cuda_stream,
-            signal_tensor.data_ptr(),
-            signal,
-            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
-        )
-        CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+def wait_eq(
+    signal_tensor: torch.Tensor,
+    signal: int,
+    stream: Optional[torch.cuda.Stream] = None,
+    require_i64: bool = False,
+):
+    stream = stream or torch.cuda.current_stream()
+    dt = signal_tensor.dtype
+    if require_i64 and dt not in (torch.int64, torch.uint64):
+        raise TypeError(f"require_i64=True but got dtype {dt}")
+    if dt == torch.int32:
+        (err,) = cuda.cuStreamWaitValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWaitValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")

This also addresses the static-analysis warning about the unused parameter. As per coding guidelines
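For context, a minimal sketch of how the set_signal/wait_eq pair is typically used across two streams (illustrative only; assumes a single CUDA device and an int32 flag):

import torch
from tilelang.distributed.utils import set_signal, wait_eq

flag = torch.zeros(1, dtype=torch.int32, device="cuda")
producer, consumer = torch.cuda.Stream(), torch.cuda.Stream()

with torch.cuda.stream(consumer):
    # Enqueue a wait on the consumer stream; work enqueued afterwards on this
    # stream runs only once the flag value becomes 1.
    wait_eq(flag, 1, stream=consumer)
    # ... enqueue kernels that depend on the producer's data ...

with torch.cuda.stream(producer):
    # ... enqueue the producing copy or kernel ...
    set_signal(flag, 1, stream=producer)

torch.cuda.synchronize()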

🧰 Tools
🪛 Ruff (0.13.1)

287-287: Unused function argument: require_i64

(ARG001)


298-298: Create your own exception

(TRY002)


298-298: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/distributed/utils.py around lines 284 to 299, the wait_eq function
currently only handles 32-bit signals and leaves the require_i64 parameter
unused; implement support for 64-bit signals and remove or repurpose require_i64
to avoid the unused-parameter lint. Change the body so that if
signal_tensor.dtype is torch.int32 you call cuda.cuStreamWaitValue32(...), if it
is torch.int64 or torch.uint64 you call cuda.cuStreamWaitValue64(...) with
stream.cuda_stream and the tensor.data_ptr(), and invoke CUDA_CHECK on the
returned error in each case; for any other dtype raise an Exception mentioning
the unsupported dtype. Finally remove the unused require_i64 parameter from the
function signature (and update any call sites) or if you prefer repurpose it to
force 64-bit semantics, document and use it to select the 64-bit path.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/distributed/__init__.py (2)

3-3: Drop the redundant noqa.

F401 isn’t enabled in our Ruff config, so the # noqa: F401 comment has no effect and just adds noise. Please remove it.


4-4: Re-export explicitly instead of using import *.

Pulling everything from .utils via * makes the package surface hard to reason about and can leak internal names. Please re-export the intended symbols explicitly (e.g., with a curated __all__) so downstream users know what’s public and static analysis stays reliable.
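For instance, an explicit surface could look like the sketch below (the symbol list is illustrative, drawn from names used elsewhere in this PR, not an exhaustive export list):

# tilelang/distributed/__init__.py -- illustrative explicit re-exports
from .common.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
from .utils import init_dist, init_distributed, dtype_map, perf_fn, set_signal, wait_eq

__all__ = [
    "_create_ipc_handle", "_sync_ipc_handles", "_create_tensor",
    "init_dist", "init_distributed", "dtype_map", "perf_fn", "set_signal", "wait_eq",
]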

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between abfbcd4 and 499a01b.

📒 Files selected for processing (1)
  • tilelang/distributed/__init__.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.13.1)
tilelang/distributed/__init__.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: from .utils import * used; unable to detect undefined names

(F403)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 499a01b and b555cff.

📒 Files selected for processing (2)
  • tilelang/distributed/__init__.py (1 hunks)
  • tilelang/distributed/common/__init__.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.13.1)
tilelang/distributed/common/__init__.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/distributed/__init__.py

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: from .utils import * used; unable to detect undefined names

(F403)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia

Comment on lines 3 to 4
from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor # noqa: F401
from .utils import * # noqa: F401

⚠️ Potential issue | 🟠 Major

Clean up package exports so Ruff stops failing.

The # noqa: F401 here triggers RUF100, and the wildcard import needs an explicit F403 suppression (or an explicit __all__) to satisfy Ruff. Making the exports explicit resolves both issues.

Apply this diff:

-from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
-from .utils import *  # noqa: F401
+from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+from .utils import *  # noqa: F403
+from .utils import __all__ as _utils_all
+
+__all__ = ["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor", *_utils_all]
+del _utils_all

As per static analysis hints

🧰 Tools
🪛 Ruff (0.13.1)

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: from .utils import * used; unable to detect undefined names

(F403)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🤖 Prompt for AI Agents
In tilelang/distributed/__init__.py around lines 3-4, the wildcard import and
blanket noqa comments trigger Ruff RUF100/F403 errors; replace the star import
with explicit symbol imports from .utils (e.g., import the specific
functions/classes used) and define a module-level __all__ list containing the
exported names (including '_create_ipc_handle', '_sync_ipc_handles',
'_create_tensor' and the explicit names imported from .utils); remove the
wildcard import and broad noqa so static analysis passes.

@@ -0,0 +1,3 @@
"""The common distributed modules"""

from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor # noqa: F401

⚠️ Potential issue | 🟠 Major

Fix Ruff RUF100 by making the re-export explicit.

Ruff flags this # noqa: F401 as an unused directive. Declaring __all__ makes the intent clear and removes the need for the directive, keeping the file lint-clean.

Apply this diff:

-from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
+from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+
+__all__ = ["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor"]

As per static analysis hints

🧰 Tools
🪛 Ruff (0.13.1)

3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🤖 Prompt for AI Agents
In tilelang/distributed/common/__init__.py around line 3, the current re-export
uses a bare import with a trailing "# noqa: F401" which Ruff flags as an unused
directive; replace the noqa by making the re-export explicit: keep the from
.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor import,
remove the "# noqa: F401" comment, and add a module-level __all__ =
["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor"] to explicitly
declare the exported symbols.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/distributed/utils.py (1)

268-281: Add support for 64-bit signals in set_signal.

The function currently only supports torch.int32 signals. Consider adding support for torch.int64 and torch.uint64 using cuStreamWriteValue64 for consistency with the wait_eq function.

 def set_signal(signal_tensor: torch.Tensor,
                signal: int,
                stream: Optional[torch.cuda.Stream] = None):
     stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
+    dt = signal_tensor.dtype
+    if dt == torch.int32:
         (err,) = cuda.cuStreamWriteValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
         )
         CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
     else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b555cff and 4904152.

📒 Files selected for processing (5)
  • CMakeLists.txt (0 hunks)
  • setup.py (4 hunks)
  • tilelang/distributed/__init__.py (1 hunks)
  • tilelang/distributed/utils.py (3 hunks)
  • tilelang/utils/allocator.py (4 hunks)
💤 Files with no reviewable changes (1)
  • CMakeLists.txt
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/utils/allocator.py (1)
tilelang/utils/cpp/tensor_from_ptr.cpp (2)
  • tensor_from_ptr (46-94)
  • tensor_from_ptr (46-48)
🪛 Ruff (0.13.1)
setup.py

846-846: Unused method argument: verbose

(ARG002)


856-856: Avoid specifying long messages outside the exception class

(TRY003)


859-859: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


864-864: subprocess call: check for execution of untrusted input

(S603)


866-868: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/distributed/__init__.py

3-3: from .utils import * used; unable to detect undefined names

(F403)


3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/distributed/utils.py

20-20: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


281-281: Create your own exception

(TRY002)


281-281: Avoid specifying long messages outside the exception class

(TRY003)


287-287: Unused function argument: require_i64

(ARG001)


298-298: Create your own exception

(TRY002)


298-298: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task
🔇 Additional comments (5)
tilelang/utils/allocator.py (1)

202-210: Guard the peer‑tensor path for unsupported modes

Two problems here:

  1. If return_peers=True on a non-distributed allocator, self._group/self._buffer_ptrs are still None, so the very first call raises AttributeError.
  2. Combining return_peers=True with take_ownership=True lets the local tensor free the underlying buffer while the peer tensors still reference it, leading to use-after-free.

We need hard guards before constructing peer_ts.

-        if return_peers:
+        if return_peers:
+            if (
+                not self._is_distributed
+                or self._group is None
+                or self._buffer_ptrs is None
+            ):
+                raise RuntimeError(
+                    "return_peers=True requires a distributed allocator with initialized peer mappings."
+                )
+            if take_ownership:
+                raise ValueError(
+                    "return_peers=True cannot be combined with take_ownership=True; the local tensor would free memory still referenced by peers."
+                )
             peer_ts = []
             for i in range(self._group.size()):
                 if i == self._local_rank:
                     peer_ts.append(t)
                 else:
tilelang/distributed/__init__.py (1)

3-4: Fix circular import and address static analysis issues.

The current imports create potential circular import issues and trigger Ruff warnings. The import from ipc_ext should be from the submodule to avoid cycles, and the wildcard import needs proper handling.

This issue was already flagged in previous review comments. Please implement the suggested fix to resolve the circular import and static analysis warnings.

tilelang/distributed/utils.py (3)

20-20: Fix circular import to avoid dependency cycles.

The import should be from the submodule rather than the package to prevent circular imports when __init__.py re-exports from utils.

This issue was already identified in previous review comments. Please change to a relative import from the submodule as suggested.


284-298: Implement 64-bit support and fix unused parameter.

The function has an unused require_i64 parameter and lacks support for 64-bit signals as mentioned in previous review comments.

This issue was already flagged in previous review comments. Please implement the suggested fix to add 64-bit support and properly handle the require_i64 parameter.


301-304: LGTM! Well-implemented CUDA stream priority utility.

The function correctly retrieves the maximum CUDA stream priority using the appropriate CUDA runtime API and includes proper error checking.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
requirements-dev.txt (1)

26-27: Guard CUDA dependency for non‑CUDA hosts; confirm 12.x floor.

To avoid breaking dev setups on macOS and other non‑CUDA environments, gate cuda‑python. Also verify that CUDA 12.x is the true minimum required by the new IPC paths.

Proposed diff:

-setuptools
-cuda-python>=12.0.0
+setuptools
+cuda-python>=12.0.0; platform_system != "Darwin"

Optional:

  • Move CUDA deps to an extra (e.g., tilescale[cuda]) or a separate requirements-cuda.txt used only in distributed jobs.
  • If Windows also isn’t supported in your CI, add and platform_system == "Linux" to be stricter.

Please confirm whether features added in this PR require CUDA 12.x specifically, or if >=11.8 would suffice.

requirements-test.txt (1)

30-31: Relax Ninja pin; align CUDA guard with dev.

Pinning ninja==1.10.0 is unnecessarily strict and dated; use a lower bound compatible with CMake and your builds. Also mirror the CUDA guard from dev to prevent failures on non‑CUDA hosts.

Proposed diff:

-ninja==1.10.0
-cuda-python>=12.0.0
+ninja>=1.10
+cuda-python>=12.0.0; platform_system != "Darwin"

Follow‑ups (optional):

  • If only some tests need Ninja/CUDA, move them into conditional extras or per‑job requirements to reduce CI surface area.
  • There are duplicate entries for Cython in this file (Lines 4 and 9); consider deduping as housekeeping.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b664ab and 8c9c0e8.

📒 Files selected for processing (2)
  • requirements-dev.txt (1 hunks)
  • requirements-test.txt (1 hunks)

@chengyupku chengyupku merged commit 2cc1a2c into main Sep 30, 2025
9 of 16 checks passed
