[Feature][Example] Support return_peers in _allocate_tensor and add ag_gemm_ipc example
#25
Conversation
Important: Review skipped — review was skipped as selected files did not have any reviewable changes. 💤 Files selected but had no reviewable changes (2)
You can disable this status message by setting the
Walkthrough
Adds an all-gather GEMM IPC example, re-exports IPC helpers and distributed utilities at the package root, adds signaling/stream helpers, extends the allocator/tensor APIs to optionally return per-peer tensors, moves PyBind extension wiring into setup.py (removing CMake custom targets), updates many import paths, applies small example cleanups, and splits CI into distributed vs non-distributed test passes.
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
  autonumber
  participant User
  participant Script as example_allgather_gemm_ipc.py
  participant Dist as tilelang.distributed
  participant Alloc as Allocator
  participant IPC as IPC Engine (cp_engine)
  participant GEMM as gemm_kernel
  participant Torch as torch.distributed
  User->>Script: main(args)
  Script->>Dist: init_distributed(...)
  Script->>Alloc: allocate A,B,C, ag_buffer (return_peers?)
  Note over Alloc: ag_buffer may return local + per-peer tensors
  Script->>Dist: set_signal / wait_eq (barriers)
  Script->>IPC: cp_engine_producer_all_gather_put(local, ag_buffer, signal...)
  IPC-->>Script: peer signals
  Script->>GEMM: launch tiled GEMM (A_gathered, B)
  GEMM-->>Script: writes C
  par Reference
    Script->>Torch: all_gather_into_tensor(...)
    Torch-->>Script: A_ref
    Script->>Torch: matmul(A_ref, B)
  end
  Script-->>User: validate results
```
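To make the signaling step in the diagram above concrete, here is a minimal sketch of the set_signal / wait_eq pattern the example relies on (helper signatures taken from tilelang/distributed/utils.py in this PR; the flag tensor and streams are illustrative only):

```python
import torch
from tilelang.distributed.utils import set_signal, wait_eq

# Illustrative int32 flag plus a dedicated copy stream.
signal = torch.zeros(1, dtype=torch.int32, device="cuda")
copy_stream = torch.cuda.Stream()
signal_target = 1

# Producer: enqueue the shard copy on the copy stream, then flip the flag so
# consumers know the data has landed (set_signal issues cuStreamWriteValue32).
with torch.cuda.stream(copy_stream):
    # ... the cudaMemcpyAsync of the local shard into the peer buffer would go here ...
    set_signal(signal, signal_target, copy_stream)

# Consumer: make the compute stream wait until the flag equals the target before
# launching any kernel that reads the gathered data (wait_eq issues cuStreamWaitValue32).
wait_eq(signal, signal_target, torch.cuda.current_stream())
```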
```mermaid
sequenceDiagram
  autonumber
  participant Caller
  participant TensorAPI as tilelang.utils.tensor.tensor
  participant Alloc as BaseAllocator._allocate_tensor
  Caller->>TensorAPI: tensor(shape, dtype, allocator, return_peers=True)
  TensorAPI->>Alloc: _allocate_tensor(shape, dtype, return_peers=True)
  Alloc-->>TensorAPI: list[Tensor] (local + per-peer views)
  TensorAPI-->>Caller: list[Tensor]
```
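For illustration, a minimal sketch of the allocation flow in this second diagram, assuming the tilelang.get_allocator / tilelang.tensor entry points used later in this review (sizes and rank values are placeholders, normally set by the launcher):

```python
import torch
import tilelang
from tilelang.distributed.utils import init_dist  # helper used by the example

M, K = 8192, 8192                    # illustrative sizes
local_rank, num_local_ranks = 0, 8   # placeholders for values the launcher provides

rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
allocator = tilelang.get_allocator(
    size=2**30,
    device="cuda",
    is_distributed=True,
    local_rank=local_rank,
    num_local_ranks=num_local_ranks,
    group=group)

# With return_peers=True the call returns one tensor view per peer, all backed by
# the same IPC-mapped buffer; indexing by rank yields that peer's view.
ag_buffer = tilelang.tensor((M, K), torch.float16, allocator=allocator, return_peers=True)
local_view = ag_buffer[local_rank]
```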
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
 Suggested reviewers
Pre-merge checks and finishing touches
✅ Passed checks (3 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run
We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
tilelang/utils/allocator.py (1)
172-201: Prevent invalid cudaFree on suballocations when take_ownership=True
If take_ownership=True on a suballocation, the tensor deleter will call cudaFree on a non-base pointer, which is undefined and can crash. Either disallow take_ownership for suballocations or require the allocation to span the entire backing buffer.
Apply this diff to guard against unsafe ownership:
```diff
         numel = _prod_shape(shape)
         itemsize = _element_size_bytes(dtype)
         bytes_needed = numel * itemsize
         bytes_alloc = _align_up(bytes_needed, self._align)
         current_offset = int(self._ptr.value) - int(self._base_ptr.value)
+        # Disallow taking ownership for suballocations: cudaFree must only free the base pointer.
+        if take_ownership and not (current_offset == 0 and bytes_alloc == self.size):
+            raise NotImplementedError(
+                "take_ownership=True is only supported when the allocation spans the entire pre-allocated buffer."
+            )
```
🧹 Nitpick comments (5)
tilelang/utils/allocator.py (1)
166-171: Return type should reflect multi-tensor return when return_peers=True
The method can return either a single tensor or a list; update the type hint to avoid confusion and downstream type issues.
Apply this diff:
```diff
-    def _allocate_tensor(self,
-                         shape: Tuple[int, ...],
-                         dtype: torch.dtype,
-                         return_peers=False,
-                         take_ownership: bool = False) -> torch.Tensor:
+    def _allocate_tensor(self,
+                         shape: Tuple[int, ...],
+                         dtype: torch.dtype,
+                         return_peers: bool = False,
+                         take_ownership: bool = False) -> Union[torch.Tensor, list[torch.Tensor]]:
```
tilelang/distributed/utils.py (2)
268-282: Support int64 signals and raise specific exceptions
Broaden dtype support and avoid a generic Exception. This also addresses static analysis warnings.
Apply this diff:
```diff
-def set_signal(signal_tensor: torch.Tensor,
-               signal: int,
-               stream: Optional[torch.cuda.Stream] = None):
+def set_signal(signal_tensor: torch.Tensor,
+               signal: int,
+               stream: Optional[torch.cuda.Stream] = None):
     stream = stream or torch.cuda.current_stream()
     if signal_tensor.dtype == torch.int32:
         (err,) = cuda.cuStreamWriteValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
         )
         CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+    elif signal_tensor.dtype == torch.int64:
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {signal_tensor.dtype}")
```
284-299: Honor require_i64 and add int64 wait path; avoid generic Exception
Implement the 64-bit wait when requested or when the tensor is int64. Use a specific exception type.
Apply this diff:
```diff
-def wait_eq(signal_tensor: torch.Tensor,
-            signal: int,
-            stream: Optional[torch.cuda.Stream] = None,
-            require_i64=False):
+def wait_eq(signal_tensor: torch.Tensor,
+            signal: int,
+            stream: Optional[torch.cuda.Stream] = None,
+            require_i64=False):
     stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
+    if signal_tensor.dtype == torch.int32 and not require_i64:
         (err,) = cuda.cuStreamWaitValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
         )
         CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+    elif signal_tensor.dtype == torch.int64 or require_i64:
+        (err,) = cuda.cuStreamWaitValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {signal_tensor.dtype}")
```
examples/distributed/example_allgather_gemm_ipc.py (2)
94-96: Clarify ag op signature: rename N to K to match semantics
The arg forwarded to the copy path is K. Rename for correctness.
Apply this diff:
```diff
-def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, group,
                local_world_size, world_size, gemm_kernel, ag_stream):
```
Also update the call site below.
99-103: Forward the renamed K to the producer
Keep parameter usage consistent with the rename.
Apply this diff:
```diff
-    with torch.cuda.stream(ag_stream):
-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+    with torch.cuda.stream(ag_stream):
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- examples/distributed/example_all_to_all.py (0 hunks)
- examples/distributed/example_allgather_gemm_ipc.py (1 hunks)
- tilelang/distributed/utils.py (2 hunks)
- tilelang/utils/allocator.py (3 hunks)
- tilelang/utils/tensor.py (1 hunks)
💤 Files with no reviewable changes (1)
- examples/distributed/example_all_to_all.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/utils/tensor.py (2)
tilelang/utils/allocator.py (4)
BaseAllocator(71-234)
initialized(163-164)
device(114-115)
_allocate_tensor(166-218)
tilelang/utils/target.py (1)
parse_device(87-100)
tilelang/utils/allocator.py (1)
tilelang/utils/cpp/tensor_from_ptr.cpp (2)
tensor_from_ptr(46-94)
tensor_from_ptr(46-48)
examples/distributed/example_allgather_gemm_ipc.py (5)
tilelang/distributed/utils.py (3)
init_dist(34-56)
set_signal(268-281)
wait_eq(284-298)
tilelang/env.py (1)
disable_cache(247-248)
tilelang/utils/allocator.py (2)
get_allocator(237-249)
device(114-115)
tilelang/jit/kernel.py (1)
initialize(400-409)
tilelang/utils/tensor.py (1)
tensor(45-58)
🪛 Ruff (0.13.1)
tilelang/distributed/utils.py
281-281: Create your own exception
(TRY002)
281-281: Avoid specifying long messages outside the exception class
(TRY003)
287-287: Unused function argument: require_i64
(ARG001)
298-298: Create your own exception
(TRY002)
298-298: Avoid specifying long messages outside the exception class
(TRY003)
examples/distributed/example_allgather_gemm_ipc.py
83-83: Unpacked variable err is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
184-184: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (2)
tilelang/utils/allocator.py (1)
128-128: Correct fix: assign pointer via .value
Copying the device pointer value into _ptr is correct and fixes the previous aliasing bug.
tilelang/utils/tensor.py (1)
56-57: Forwarding return_peers — LGTM
Propagating return_peers to the allocator is correct and aligns the public API with allocator behavior.
```python
from tilelang.distributed.utils import init_dist
from cuda import cudart
from tilelang.distributed.utils import set_signal, wait_eq
```
Make cudart import compatible with cuda-python >= 12.8 and import CUDA_CHECK
Directly importing from cuda import cudart breaks on newer cuda-python. Use the same version-gated import as utils and import CUDA_CHECK for error handling.
Apply this diff:
```diff
-from tilelang.distributed.utils import init_dist
-from cuda import cudart
-from tilelang.distributed.utils import set_signal, wait_eq
+from tilelang.distributed.utils import init_dist
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart
+from tilelang.distributed.utils import set_signal, wait_eq, CUDA_CHECK
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 8 to 11, the
direct "from cuda import cudart" import breaks on cuda-python >= 12.8; update
the imports to match the version-gated pattern used in
tilelang.distributed.utils (perform a conditional import that tries the new
module path and falls back to the old one) and also import CUDA_CHECK for error
handling, then replace usages of cudart error checks to use CUDA_CHECK where
appropriate.
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    A_shared = T.alloc_shared((block_M, block_K), dtype)
    B_shared = T.alloc_shared((block_K, block_N), dtype)
    C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

    T.clear(C_local)
    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
        T.copy(A[by * block_M, k * block_K], A_shared)
        T.copy(B[k * block_K, bx * block_N], B_shared)
        T.gemm(A_shared, B_shared, C_local)
    T.copy(C_local, C[by * block_M, bx * block_N])
```
Fix grid size: using N instead of N_per_rank causes out-of-bounds on B/C
Kernel grid must reflect local N dimension per rank. Using global N overshoots.
Apply this diff:
```diff
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 33 to 44 the
kernel grid uses the global N which causes out-of-bounds accesses for B/C on
each rank; change the x-dimension grid calculation to use the per-rank local
width (replace T.ceildiv(N, block_N) with T.ceildiv(N_per_rank, block_N)) and
ensure any indexing/limits that assume the x-dimension (reads from B and writes
to C) are bounded by N_per_rank so the kernel only addresses the local slice.
```python
def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
                                      signal_target, rank, local_world_size, world_size,
                                      intranode_ag_stream):
    local_rank = rank % local_world_size
```
Parameter naming bug: N here is actually K (second dim of A)
Misnaming invites misuse. Rename to K and propagate to all uses.
Apply this diff:
```diff
-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):
```
Follow-up diffs below update its uses.
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 48 to 51, the
function parameter named N is incorrect — it represents K (the second dimension
of A); rename the parameter to K in the function signature and update every call
site and internal reference to use K instead of N, ensuring variable names and
any related comments/docstrings reflect the new name so behavior remains
unchanged.
```python
    for i in range(1, local_world_size):
        segment = rank * M_per_rank * N
        local_dst_rank = (local_rank + local_world_size - i) % local_world_size
        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
        # Using copy engine to perform intranode transmission
        # Sending rank-th local tensor to other ranks inside the node.
        (err,) = cudart.cudaMemcpyAsync(
            dst_ptr,
            src_ptr,
            M_per_rank * N * local_tensor.element_size(),
            cudart.cudaMemcpyKind.cudaMemcpyDefault,
            intranode_ag_stream.cuda_stream,
        )
        # Notify the peer that the transmission is done.
```
🛠️ Refactor suggestion | 🟠 Major
Use K for segment sizing and add CUDA error checks
Copy size/offset must be based on the width of A (K). Also check cudaMemcpyAsync return codes.
Apply this diff:
```diff
-    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
-        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
-        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
+        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
+        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
         (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)
```
```python
    for i in range(1, n_nodes):
        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
        recv_segment = recv_rank * M_per_rank * N
        # Waiting for the internode data ready
        wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
        src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
        for j in range(1, local_world_size):
            local_dst_rank = (local_rank + local_world_size - j) % local_world_size
            dst_ptr = ag_buffer[local_dst_rank].data_ptr(
            ) + recv_segment * local_tensor.element_size()
            # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
            (err,) = cudart.cudaMemcpyAsync(
                dst_ptr,
                src_ptr,
                M_per_rank * N * local_tensor.element_size(),
                cudart.cudaMemcpyKind.cudaMemcpyDefault,
                intranode_ag_stream.cuda_stream,
            )
            # Notify the peer that the transmission is done.
            set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
```
🛠️ Refactor suggestion | 🟠 Major
Likewise for inter-node segment sizing and error checks
Mirror the K-based sizing and add error handling.
Apply this diff:
```diff
-    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
             (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
```
🧰 Tools
🪛 Ruff (0.13.1)
83-83: Unpacked variable err is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 72 to 92, the inter-node segment sizing currently uses N but should mirror the K-based sizing, and it lacks error checks on cudaMemcpyAsync; change recv_segment and the byte count passed to cudaMemcpyAsync to use M_per_rank * K (times element_size), check the returned err from cudart.cudaMemcpyAsync and raise or log a RuntimeError if err is non-zero, and only call set_signal after the copy has been enqueued successfully (or after handling the error).
```python
    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
```
Pass the correct world size (global) to ag_gemm_op
You currently pass num_local_ranks for both local and global sizes. Use the true num_ranks for world_size.
Apply this diff:
```diff
-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 173 to 175,
the call to ag_gemm_op passes num_local_ranks twice (for both local and
global/world sizes); change the second num_local_ranks argument to the true
global world size variable num_ranks so the function receives local_size,
world_size correctly—update the ag_gemm_op call to pass num_ranks as the
world_size parameter.
```python
        if return_peers:
            peer_ts = []
            for i in range(self._group.size()):
                if i == self._local_rank:
                    peer_ts.append(t)
                else:
                    peer_ptr_val = int(self._buffer_ptrs[i]) + current_offset
                    peer_t = tensor_from_ptr(peer_ptr_val, shape, dtype_str, self._device, False)
                    peer_ts.append(peer_t)
```
Guard return_peers usage for non-distributed allocators
When return_peers=True, the code assumes _group and _buffer_ptrs are initialized; without a distributed setup this fails at runtime.
Apply this diff:
```diff
-        if return_peers:
+        if return_peers:
+            if not self._is_distributed or self._group is None or self._buffer_ptrs is None:
+                raise RuntimeError(
+                    "return_peers=True requires a distributed allocator with initialized peer mappings."
+                )
             peer_ts = []
             for i in range(self._group.size()):
                 if i == self._local_rank:
                     peer_ts.append(t)
                 else:
                     peer_ptr_val = int(self._buffer_ptrs[i]) + current_offset
                     peer_t = tensor_from_ptr(peer_ptr_val, shape, dtype_str, self._device, False)
                     peer_ts.append(peer_t)
```
🤖 Prompt for AI Agents
In tilelang/utils/allocator.py around lines 202 to 211, the block that builds peer_ts when return_peers=True assumes self._group and self._buffer_ptrs exist and that the allocator is distributed; guard this by checking whether the allocator is running in
distributed mode (e.g., if self._group is None or not getattr(self,
"_buffer_ptrs", None):) before using them, and if the check fails raise a clear
ValueError or RuntimeError stating that return_peers requires a distributed
allocator (or instruct the caller to set return_peers=False), otherwise proceed
to build peer_ts as before.
Actionable comments posted: 3
🧹 Nitpick comments (1)
examples/distributed/example_allgather_gemm_ipc.py (1)
94-106: Align ag_gemm_op signature with K and compute local_rank
Prevents confusion and enables correct local peer indexing later.
Apply this diff:
```diff
-def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, group,
                local_world_size, world_size, gemm_kernel, ag_stream):
     dist.barrier(group)
     # all_gather A to ag_buffer
     with torch.cuda.stream(ag_stream):
-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)
     current_stream = torch.cuda.current_stream()
     current_stream.wait_stream(ag_stream)
+    local_rank = rank % local_world_size
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- examples/distributed/example_allgather_gemm_ipc.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/distributed/example_allgather_gemm_ipc.py (6)
tilelang/distributed/utils.py (3)
init_dist(34-56)
set_signal(268-281)
wait_eq(284-298)
tilelang/env.py (1)
disable_cache(247-248)
tilelang/utils/allocator.py (2)
get_allocator(237-249)
device(114-115)
tilelang/jit/__init__.py (1)
compile(32-81)
tilelang/jit/kernel.py (1)
initialize(400-409)
tilelang/utils/tensor.py (1)
tensor(45-58)
🪛 Ruff (0.13.1)
examples/distributed/example_allgather_gemm_ipc.py
83-83: Unpacked variable err is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
183-183: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (8)
examples/distributed/example_allgather_gemm_ipc.py (8)
48-51: Rename parameter: N here is K (width of A)
Avoid confusion and downstream misuse by renaming.
Apply this diff:
```diff
-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):
```
Follow-up: update all internal references and call sites to use K (see separate comments).
119-127: Torch reference path looks good
Using all_gather_into_tensor + matmul is correct for the reference result.
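For reference, a minimal sketch of that reference path, assuming A is the local (M_per_rank, K) shard and B the local (K, N_per_rank) slice as in the example; the function name is illustrative, not part of the example's API:

```python
import torch
import torch.distributed as dist

def reference_ag_gemm(A: torch.Tensor, B: torch.Tensor, M: int, group) -> torch.Tensor:
    # Gather every rank's A shard into the full (M, K) matrix, then run a plain matmul.
    A_full = torch.empty((M, A.shape[1]), dtype=A.dtype, device=A.device)
    dist.all_gather_into_tensor(A_full, A, group=group)
    return torch.matmul(A_full, B)
```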
8-11: Gate cudart import for cuda-python >= 12.8 and import CUDA_CHECK
Direct from cuda import cudart breaks on newer cuda-python. Mirror the version-gated import pattern and bring in CUDA_CHECK for error handling.
Apply this diff:
```diff
-from tilelang.distributed.utils import init_dist
-from cuda import cudart
-from tilelang.distributed.utils import set_signal, wait_eq
+from tilelang.distributed.utils import init_dist
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart
+from tilelang.distributed.utils import set_signal, wait_eq, CUDA_CHECK
```
33-44: Fix grid x-dimension to use local N per rank
Kernel grid uses global N, causing out-of-bounds for B/C which are sized to N // num_rank.
Apply this diff:
```diff
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
```
55-71: Use K for segment sizing and check cudaMemcpyAsync errors
Byte counts/offsets must use K. Also check return codes.
Apply this diff:
```diff
-    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
         src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
         dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
-        (err,) = cudart.cudaMemcpyAsync(
+        (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)
```
72-92: Mirror K-based sizing for inter-node path and add error checks
Also resolves the “err is never used” warning.
Apply this diff:
```diff
-    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
-            (err,) = cudart.cudaMemcpyAsync(
+            (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
```
172-174: Pass global world_size to ag_gemm_op
You currently pass num_local_ranks for both local and global sizes.
Apply this diff:
```diff
-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)
```
12-12: Module-level disable_cache alias is defined
disable_cache = CacheState.disable is declared in tilelang/env.py and re-exported in tilelang/__init__.py, so tilelang.disable_cache() is valid. No changes needed.
Likely an incorrect or invalid review comment.
```python
    torch.cuda.synchronize()
    torch.distributed.barrier(group)
    gemm_kernel(ag_buffer[rank], B, C)
    torch.cuda.synchronize()
    torch.distributed.barrier(group)
```
Use local peer for kernel input; drop redundant sync
Fix out-of-range access when rank >= local_world_size. Also remove duplicate synchronize.
Apply this diff:
```diff
-    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)
     torch.cuda.synchronize()
-    torch.distributed.barrier(group)
+    torch.distributed.barrier(group)
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 110 to 114,
the code uses the global rank as the index into ag_buffer which can be
out-of-range when rank >= local_world_size; change the kernel input index to a
local peer index (e.g., local_peer = rank % local_world_size or however
local_world_size is computed in this module) and call
gemm_kernel(ag_buffer[local_peer], B, C) instead of ag_buffer[rank]; also remove
the redundant torch.cuda.synchronize() (keep a single synchronize after the
kernel) so you only call torch.distributed.barrier(group), gemm_kernel(... with
local_peer ...), torch.cuda.synchronize(), and torch.distributed.barrier(group).
```python
def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    dtype = torch.float16
    M = args.M if args else 8192
    N = args.N if args else 8192
    K = args.K if args else 8192
    M_per_rank = M // num_local_ranks
    N_per_rank = N // num_local_ranks

    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    threads = 256

    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    allocator = tilelang.get_allocator(
        size=2**30,
        device="cuda",
        is_distributed=True,
        local_rank=local_rank,
        num_local_ranks=num_local_ranks,
        group=group)
    kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads))
```
Compute per-rank sizes using the global world size and after init_dist
Current M_per_rank/N_per_rank use num_local_ranks, which breaks multi-node (OOB segment indexing).
Apply this diff:
```diff
-    M = args.M if args else 8192
-    N = args.N if args else 8192
-    K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks
+    M = args.M if args else 8192
+    N = args.N if args else 8192
+    K = args.K if args else 8192
@@
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 130-151,
M_per_rank and N_per_rank are computed too early using num_local_ranks which
breaks multi-node indexing; move the M_per_rank/N_per_rank calculations to after
the init_dist(...) call and compute them with the global world size (use
num_ranks returned by init_dist), e.g. M_per_rank = M // num_ranks and
N_per_rank = N // num_ranks, ensuring the per-rank sizes are based on the global
number of ranks and evaluated after init_dist.
```python
    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
    signal_buffer = tilelang.tensor((num_local_ranks,),
                                    torch.int32,
                                    allocator=allocator,
                                    return_peers=True)
    signal_buffer[rank].fill_(0)
    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
```
Fix peer buffer/signal shapes and local indexing
- signal_buffer inner length must be world_size (you index by global rank).
- Index ag_buffer/signal_buffer by local_rank, not global rank.
Apply this diff:
```diff
-    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
-    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
-    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
-    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
-    signal_buffer = tilelang.tensor((num_local_ranks,),
-                                    torch.int32,
-                                    allocator=allocator,
-                                    return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
+    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
+    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
+    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
+    signal_buffer = tilelang.tensor((num_ranks,),
+                                    torch.int32,
+                                    allocator=allocator,
+                                    return_peers=True)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
```
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 156 to 166, the peer signal buffer is created with length num_local_ranks and the peer lists are indexed by the global rank; change the signal_buffer inner length to the global world size (since it is indexed by global ranks across all peers) and index the peer lists (ag_buffer, signal_buffer) by local_rank for local writes, while keeping the row slice based on the global rank offset (rank * M_per_rank .. (rank + 1) * M_per_rank). Keep the buffer constructions using allocator and return_peers as before, only with the corrected shape for signal_buffer.
Actionable comments posted: 9
🧹 Nitpick comments (5)
.github/workflows/ci.yml (2)
126-133: Optional: simplify discovery by using pytest selection instead of find/mapfile
You can avoid brittle shell/globbing by letting pytest handle selection (keeps reporting consistent, works across platforms).
Example (no need to change now):
- Distributed examples:
- TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 -v -r fE -k "distributed and test"
- Non-distributed examples:
- python -m pytest -n 4 -v -r fE -k "not distributed and test"
152-159: Optional: let pytest handle non-distributed selection
Using -k avoids maintaining separate find expressions and scales better as tests move.
Example (no immediate change required):
- python -m pytest -n 4 -v -r fE -k "not distributed and test"
tilelang/distributed/utils.py (1)
268-282: Signal helpers: support 64-bit signals and raise precise exceptions.
- set_signal only supports int32; many call sites use uint64 signals.
- Use cuStreamWriteValue64 when tensor is int64/uint64.
- Raise TypeError instead of Exception.
Apply:
```diff
-def set_signal(signal_tensor: torch.Tensor,
-               signal: int,
-               stream: Optional[torch.cuda.Stream] = None):
-    stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
-        (err,) = cuda.cuStreamWriteValue32(
-            stream.cuda_stream,
-            signal_tensor.data_ptr(),
-            signal,
-            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
-        )
-        CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+def set_signal(
+    signal_tensor: torch.Tensor,
+    signal: int,
+    stream: Optional[torch.cuda.Stream] = None,
+):
+    stream = stream or torch.cuda.current_stream()
+    dt = signal_tensor.dtype
+    if dt == torch.int32:
+        (err,) = cuda.cuStreamWriteValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")
```
Additionally, define a small custom exception if preferred by your style guide:
```python
class SignalDTypeError(TypeError): pass
```
And replace TypeError with SignalDTypeError. As per coding guidelines.
examples/distributed/example_allgather_gemm_ipc.py (2)
110-114: Remove redundant synchronization calls.
Multiple consecutive synchronization calls are redundant and impact performance.
Apply this diff to clean up synchronization:
```diff
-    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
     gemm_kernel(ag_buffer[rank], B, C)
     torch.cuda.synchronize()
     torch.distributed.barrier(group)
```
183-183: Use exception chaining for better debugging context.
When raising exceptions after a check failure, preserve the original context.
Apply this diff:
```diff
-        raise ValueError("Test failed")
+        raise ValueError(f"Correctness check failed for rank {local_rank}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (27)
- .github/workflows/ci.yml(1 hunks)
- benchmark/distributed/benchmark_ag_gemm.py(1 hunks)
- benchmark/distributed/benchmark_all_gather.py(1 hunks)
- benchmark/distributed/benchmark_all_to_all.py(1 hunks)
- benchmark/distributed/benchmark_gemm_rs.py(1 hunks)
- benchmark/distributed/benchmark_reduce_scatter.py(1 hunks)
- benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py(1 hunks)
- benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py(1 hunks)
- examples/distributed/example_all_to_all.py(1 hunks)
- examples/distributed/example_allgather.py(1 hunks)
- examples/distributed/example_allgather_gemm.py(1 hunks)
- examples/distributed/example_allgather_gemm_ipc.py(1 hunks)
- examples/distributed/example_cannon.py(1 hunks)
- examples/distributed/example_gemm_rs.py(1 hunks)
- examples/distributed/example_post_attn_all2all_transpose.py(3 hunks)
- examples/distributed/example_pre_attn_all2all.py(3 hunks)
- examples/distributed/example_pre_attn_all2all_transpose.py(3 hunks)
- examples/distributed/example_simple_shift.py(1 hunks)
- examples/distributed/example_summa.py(1 hunks)
- examples/distributed/primitives/example_get_block.py(1 hunks)
- examples/distributed/primitives/example_get_warp.py(1 hunks)
- examples/distributed/primitives/example_put_block.py(1 hunks)
- examples/distributed/primitives/example_put_warp.py(1 hunks)
- tilelang/distributed/common/README.md(1 hunks)
- tilelang/distributed/testing/sync/test_barrierall_sys.py(1 hunks)
- tilelang/distributed/utils.py(3 hunks)
- tilelang/utils/allocator.py(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/utils/allocator.py
🧰 Additional context used
🧬 Code graph analysis (23)
examples/distributed/primitives/example_get_warp.py (1)
tilelang/distributed/utils.py (1)
init_dist(34-56)
examples/distributed/example_post_attn_all2all_transpose.py (2)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/example_pre_attn_all2all.py (1)
verify_results(138-161)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1)
tilelang/distributed/utils.py (2)
init_dist(34-56)
perf_fn(217-238)
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
benchmark/distributed/benchmark_ag_gemm.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
benchmark/distributed/benchmark_all_gather.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
examples/distributed/example_all_to_all.py (2)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
tilelang/profiler/__init__.py (1)
init_distributed(67-88)
benchmark/distributed/benchmark_gemm_rs.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
benchmark/distributed/benchmark_all_to_all.py (1)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
tilelang/distributed/testing/sync/test_barrierall_sys.py (1)
tilelang/distributed/utils.py (1)
init_dist(34-56)
examples/distributed/example_gemm_rs.py (2)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
generate_data(168-170)
tilelang/profiler/__init__.py (1)
init_distributed(67-88)
examples/distributed/example_cannon.py (1)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/primitives/example_put_warp.py (1)
tilelang/distributed/utils.py (1)
init_dist(34-56)
examples/distributed/example_allgather_gemm.py (1)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/primitives/example_get_block.py (1)
tilelang/distributed/utils.py (1)
init_dist(34-56)
examples/distributed/example_pre_attn_all2all_transpose.py (2)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/example_post_attn_all2all_transpose.py (1)
verify_results(141-164)
examples/distributed/primitives/example_put_block.py (1)
tilelang/distributed/utils.py (1)
init_dist(34-56)
examples/distributed/example_allgather_gemm_ipc.py (6)
tilelang/distributed/utils.py (3)
init_dist(34-56)
set_signal(268-281)
wait_eq(284-298)
examples/distributed/example_allgather_gemm.py (1)
main(15-50)
tilelang/utils/allocator.py (2)
get_allocator(237-249)
device(114-115)
tilelang/jit/__init__.py (1)
compile(32-81)
tilelang/jit/kernel.py (1)
initialize(400-409)
tilelang/utils/tensor.py (1)
tensor(45-58)
examples/distributed/example_summa.py (1)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/example_simple_shift.py (2)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
tilelang/profiler/__init__.py (1)
init_distributed(67-88)
examples/distributed/example_pre_attn_all2all.py (2)
tilelang/distributed/utils.py (1)
init_distributed(59-82)
examples/distributed/example_post_attn_all2all_transpose.py (1)
verify_results(141-164)
benchmark/distributed/benchmark_reduce_scatter.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
examples/distributed/example_allgather.py (1)
tilelang/distributed/utils.py (2)
init_distributed(59-82)
perf_fn(217-238)
🪛 Ruff (0.13.1)
examples/distributed/example_allgather_gemm_ipc.py
83-83: Unpacked variable err is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
183-183: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/distributed/utils.py
281-281: Create your own exception
(TRY002)
281-281: Avoid specifying long messages outside the exception class
(TRY003)
287-287: Unused function argument: require_i64
(ARG001)
298-298: Create your own exception
(TRY002)
298-298: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
🔇 Additional comments (32)
benchmark/distributed/ipc_impls/benchmark_unrolledcp_p2p.py (1)
8-8: Import path aligns with new public API surface.
`tilelang.distributed` now re-exports `init_dist` and `perf_fn`, so this change keeps the benchmark in sync with the refreshed namespace.
examples/distributed/example_post_attn_all2all_transpose.py (3)
6-6: Consolidated distributed imports look good.
Pulling `init_distributed` and `dtype_map` from the package root keeps these examples aligned with the new public surface.
45-45: Loop variable cleanup appreciated.
Switching to `_` makes it explicit the index isn't used and quiets static analyzers.
265-265: Direct verification call is fine.
Dropping the unused local keeps the flow lean while preserving the check.
benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py (1)
12-12: Import consolidation to the public API is fine; verify re-exports.
Ensure `tilelang.distributed` publicly exposes `init_distributed` and `perf_fn` without introducing a utils ↔ init import cycle.
Use the same script provided in the other comment to confirm the `__init__` re-exports and no circular imports.
tilelang/distributed/utils.py (1)
301-304: cuda_stream_max_priority: LGTM.
examples/distributed/primitives/example_put_warp.py (1)
8-8: Import path update is fine; ensure init_dist is re-exported from tilelang.distributed.
Confirm `tilelang.distributed` exposes `init_dist` to avoid breaking this example.
Reuse the verification script posted earlier to inspect tilelang/distributed/__init__.py for `init_dist` re-exports and to detect cycles.
examples/distributed/example_cannon.py (1)
6-6: Public API import switch looks good.
No other changes in this file; behavior unaffected.
benchmark/distributed/benchmark_gemm_rs.py (1)
12-12: Import move to the public API namespace is OK; verify re-exports.
Ensure `tilelang.distributed` re-exports `init_distributed`, `dtype_map`, and `perf_fn` to prevent runtime failures.
Use the same __init__.py verification script mentioned previously.
examples/distributed/example_allgather_gemm.py (1)
7-7: Import consolidation to tilelang.distributed is fine.
No behavior change; just verify `init_distributed` is re-exported in __init__.py.
Same verification script as earlier applies.
examples/distributed/primitives/example_get_warp.py (1)
8-8: Import path change LGTM; confirm re-export of init_dist.
Avoid breakage by ensuring `tilelang.distributed` exposes `init_dist`.
Same __init__.py verification script applies.
benchmark/distributed/benchmark_all_to_all.py (1)
4-4: Import-path update looks correct.
Switching to the public `tilelang.distributed` re-export keeps behavior unchanged and aligns with the PR-wide consolidation. 👍
examples/distributed/primitives/example_get_block.py (1)
8-8: Import routed through public module.
Using `tilelang.distributed` for `init_dist` matches the new API surface and maintains existing functionality.
examples/distributed/primitives/example_put_block.py (1)
8-8: Public import confirmed.
The updated import mirrors the new re-export and leaves runtime behavior intact.
examples/distributed/example_all_to_all.py (1)
6-6: Re-export usage is consistent.
The module now sources `init_distributed` from the consolidated public package; no other changes required.
tilelang/distributed/testing/sync/test_barrierall_sys.py (1)
9-9: Test import aligned with public API.
Pointing the test to `tilelang.distributed` keeps it in sync with the new public surface.
benchmark/distributed/benchmark_all_gather.py (1)
7-7: Benchmark import path updated correctly.
All three utilities (`init_distributed`, `dtype_map`, `perf_fn`) are now pulled from the public module, matching the project-wide change.
examples/distributed/example_summa.py (1)
6-6: Public re-export in use.
The example now depends on the consolidated `tilelang.distributed` surface; no issues spotted.
examples/distributed/example_allgather.py (1)
7-7: Allgather example aligned with new API.
Imports of `init_distributed`, `dtype_map`, and `perf_fn` now match the re-exported public interface.
benchmark/distributed/benchmark_ag_gemm.py (1)
21-21: LGTM! Import path consolidation aligns with project-wide refactor.
The import path change from `tilelang.distributed.utils` to `tilelang.distributed` correctly reflects the public API surface reorganization happening across the project.
examples/distributed/example_gemm_rs.py (1)
5-5: LGTM! Import path update matches project conventions.
The change from `tilelang.distributed.utils` to `tilelang.distributed` is consistent with the broader public API reorganization.
tilelang/distributed/common/README.md (1)
15-15: LGTM! Documentation correctly reflects API surface changes.
The README appropriately documents the new import path from `tilelang.distributed` instead of the previous `tilelang.distributed.utils`.
examples/distributed/example_allgather_gemm_ipc.py (10)
9-10: Fix cuda-python compatibility for the cudart import.
Direct import from `cuda` breaks on cuda-python >= 12.8. Use a version-gated import pattern.
Apply this diff to fix the import compatibility:
-from cuda import cudart
+import importlib.metadata
+from packaging import version
+cuda_python_version = importlib.metadata.version("cuda-python")
+if version.parse(cuda_python_version) >= version.parse("12.8.0"):
+    from cuda.bindings import runtime as cudart
+else:
+    from cuda import cudart
33-33: Fix grid dimensions to use the per-rank local N dimension.
The kernel grid uses the global N, which causes out-of-bounds accesses. Must use the local N_per_rank = N // num_rank.
Apply this diff:
-    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+    with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
48-50: Rename parameter N to K to match tensor dimensions.
The parameter `N` actually represents the K dimension (the second dimension of A). This misnaming causes confusion.
Apply this diff:
-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size, intranode_ag_stream):
Also update all usages of `N` within this function to `K`.
56-56: Use K for segment sizing and add error checking.
Segment calculation and copy size must use the K dimension. Also add CUDA error checking.
Apply this diff:
-    segment = rank * M_per_rank * N
+    segment = rank * M_per_rank * K
...
-        M_per_rank * N * local_tensor.element_size(),
+        M_per_rank * K * local_tensor.element_size(),
Also add error checking after line 68:
+    from tilelang.distributed import CUDA_CHECK
+    CUDA_CHECK(err)
74-74: Fix inter-node segment sizing to use K.
Inter-node segments and copy sizes should also use the K dimension.
Apply this diff:
-        recv_segment = recv_rank * M_per_rank * N
+        recv_segment = recv_rank * M_per_rank * K
...
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
Also applies to: 86-86
112-112: Use local_rank to index ag_buffer.
Using the global `rank` to index `ag_buffer` causes out-of-bounds access when rank >= num_local_ranks.
Apply this diff:
-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)
Also add the local_rank computation at the start of the function:
+    local_rank = rank % local_world_size
135-136: Compute per-rank sizes using the global world size after init_dist.
Per-rank sizes computed with `num_local_ranks` break multi-node setups. Move the calculation after `init_dist`.
Apply this diff:
     K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks
     BLOCK_M = 128
     BLOCK_N = 128
     BLOCK_K = 64
     threads = 256
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks
Also applies to: 143-143
160-165: Fix the signal buffer size and use local_rank for indexing.
The signal buffer must be sized for the global world_size since you index by global rank. Also use local_rank for local operations.
Apply this diff:
-    signal_buffer = tilelang.tensor((num_local_ranks,),
+    signal_buffer = tilelang.tensor((num_ranks,),
                                     torch.int32, allocator=allocator, return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
172-173: Pass the correct world_size to ag_gemm_op.
Currently passing `num_local_ranks` for both the local and world sizes. Should pass `num_ranks` for world_size.
Apply this diff:
     tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)
83-83: Add CUDA error checking or use an underscore prefix.
The error return value is unpacked but not checked, which could hide failures.
Apply this diff to add error checking:
-    (err,) = cudart.cudaMemcpyAsync(
+    (err,) = cudart.cudaMemcpyAsync(
         dst_ptr,
         src_ptr,
-        M_per_rank * N * local_tensor.element_size(),
+        M_per_rank * K * local_tensor.element_size(),
         cudart.cudaMemcpyKind.cudaMemcpyDefault,
         intranode_ag_stream.cuda_stream,
     )
+    from tilelang.distributed import CUDA_CHECK
+    CUDA_CHECK(err)
Or if error checking is intentionally omitted, use an underscore prefix:
-    (err,) = cudart.cudaMemcpyAsync(
+    (_err,) = cudart.cudaMemcpyAsync(
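To make the producer side of this flow easier to follow, here is a hedged sketch. The per-peer buffer lists are assumed to be what the new `return_peers=True` allocation returns, and `set_signal`/`wait_eq` are the helpers reviewed above in `tilelang/distributed/utils.py`; the function name, slicing, and error handling below are illustrative assumptions, not the example's actual code.

```python
# Sketch only: push this rank's shard of A into every peer's gathered buffer
# over the copy engine, then set that peer's flag for this rank.
from cuda import cudart  # see the version-gated import suggested above
from tilelang.distributed import set_signal


def push_local_shard(local_A, ag_peers, signal_peers, rank, world_size,
                     m_per_rank, k, stream, signal_target=1):
    nbytes = m_per_rank * k * local_A.element_size()
    for peer in range(world_size):
        # destination: this rank's row block inside `peer`'s gathered buffer
        dst = ag_peers[peer][rank * m_per_rank:(rank + 1) * m_per_rank, :]
        (err,) = cudart.cudaMemcpyAsync(
            dst.data_ptr(), local_A.data_ptr(), nbytes,
            cudart.cudaMemcpyKind.cudaMemcpyDefault, stream.cuda_stream)
        assert err == cudart.cudaError_t.cudaSuccess
        # publish "rows from `rank` are ready" to that peer
        set_signal(signal_peers[peer][rank], signal_target, stream=stream)
```

The consumer side mirrors this: before the GEMM reads the rows owned by `src_rank`, it would call something like `wait_eq(local_signal[src_rank], signal_target)` on the compute stream, so only that stream, not the host, blocks.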
        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
        mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
        if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
          echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
          printf '%s\n' "${DIST_TESTS[@]}"
          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
        else
          echo "No distributed examples found."
        fi
Run distributed example tests serially and harden shell options
Parallelizing distributed GPU tests with pytest-xdist (-n 4) commonly oversubscribes GPUs/process groups and causes flaky hangs/OOMs. Also add strict shell flags to fail fast on CI.
Apply this diff:
-        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed examples found."
         fi
🤖 Prompt for AI Agents
In .github/workflows/ci.yml around lines 115 to 123, the workflow currently runs
distributed GPU tests in parallel and lacks strict shell failure flags; change
it to run the distributed tests serially (remove pytest-xdist -n 4) so
TILELANG_USE_DISTRIBUTED=1 python -m pytest "${DIST_TESTS[@]}" -v -r fE is used,
and add strict shell options at the start of the shell block (e.g., set -euo
pipefail and IFS=$'\n\t') so the job fails fast on errors and handles unset
variables/pipe failures robustly.
        # run distributed tests first with env var
        mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
        if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
          echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
          printf '%s\n' "${DIST_TESTS[@]}"
          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
        else
          echo "No distributed tests found under testing/python."
        fi
Run distributed tests (testing/python) serially and add strict shell flags
Same flakiness risk here from -n 4 on GPU-distributed tests; also add set -euo pipefail for safety.
Apply this diff:
-        # run distributed tests first with env var
+        # run distributed tests first with env var
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed tests found under testing/python."
         fi
🤖 Prompt for AI Agents
.github/workflows/ci.yml lines 141-149: the workflow runs GPU-distributed tests
in parallel with pytest -n 4 which causes flakiness and lacks strict shell
safety; remove the -n 4 flag so pytest runs serially for those distributed
tests, and add strict shell flags (set -euo pipefail) before this block (or at
top of the script/job) so failures, undefined variables, and pipe errors abort
the job; ensure TILELANG_USE_DISTRIBUTED=1 is preserved when invoking pytest.
    # all_gather A to ag_buffer
    with torch.cuda.stream(ag_stream):
        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
Pass correct K dimension to cp_engine_producer_all_gather_put.
The function call passes N but should pass K to match the corrected function signature.
Apply this diff:
-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                          rank, local_world_size, world_size, ag_stream)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around line 101, the call
to cp_engine_producer_all_gather_put incorrectly passes N as the K dimension;
update the call to pass K instead of N so the arguments match the corrected
function signature (replace the N argument with K in that function call).
     print(f"PE {RANK} TileLang output sample: {tilelang_output[0, 0, 0, :3]}")

-    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verify_results(tilelang_output, torch_output, RANK)
Restore verification failure propagation
By discarding the boolean return from verify_results, any mismatch now only prints to stdout and the script exits successfully, hiding real correctness regressions. Please reintroduce the check so the example fails fast when verification fails.
-    verify_results(tilelang_output, torch_output, RANK)
+    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    if not verification_passed:
+        raise RuntimeError(f"PE {RANK}: TileLang output does not match PyTorch reference")
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all_transpose.py around line 260,
the call to verify_results(tilelang_output, torch_output, RANK) currently
ignores its boolean return value; restore failure propagation by capturing the
return and exiting non-zero on failure (e.g., if not verify_results(...) then
print/log error and call sys.exit(1) or raise an exception) so the script fails
fast when verification fails.
     print(f"PE {RANK} TileLang output sample: {tilelang_output[0, 0, 0, :3]}")

-    verification_passed = verify_results(tilelang_output, torch_output, RANK)
+    verify_results(tilelang_output, torch_output, RANK)
Restore failure propagation from verify_results.
We now drop the boolean return, so the script exits successfully even when verification fails—automations won’t catch regressions. Capture the return and exit (or raise) on failure.
-    verify_results(tilelang_output, torch_output, RANK)
+    success = verify_results(tilelang_output, torch_output, RANK)
+    if not success:
+        raise SystemExit(1)
🤖 Prompt for AI Agents
In examples/distributed/example_pre_attn_all2all.py around line 263, the call to
verify_results(tilelang_output, torch_output, RANK) currently ignores its
boolean result so failures are treated as success; capture the return value and
fail the script on false by either raising an exception or calling sys.exit with
a non-zero status (e.g., import sys and if not ok: sys.exit(1)) so CI and
automation detect verification regressions.
 import tilelang.language as T
 from tilelang.profiler import TensorSupplyType
-from tilelang.distributed.utils import init_distributed
+from tilelang.distributed import init_distributed
🧩 Analysis chain
Confirm init_distributed re-export
Switching the import to from tilelang.distributed import init_distributed assumes the package’s __init__ now re-exports that symbol; otherwise this will raise an AttributeError at runtime. Please confirm the re-export exists.
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Verify that tilelang/distributed/__init__.py exposes init_distributed.
if [[ -f tilelang/distributed/__init__.py ]]; then
  rg -n "init_distributed" tilelang/distributed/__init__.py
else
  echo "tilelang/distributed/__init__.py missing"
fi
Length of output: 40
Fix import for init_distributed
The tilelang/distributed directory has no __init__.py, so `from tilelang.distributed import init_distributed` imports the submodule, not the function. Either:
- Add tilelang/distributed/__init__.py with `from .init_distributed import init_distributed`
- Or update examples/distributed/example_simple_shift.py:4 to `from tilelang.distributed.init_distributed import init_distributed`
🤖 Prompt for AI Agents
In examples/distributed/example_simple_shift.py around line 4 the statement
imports the submodule instead of the function because tilelang/distributed is
not a package; change the import to reference the specific module path or add an
__init__.py that re-exports the symbol. Either update the file to import
init_distributed from tilelang.distributed.init_distributed, or add
tilelang/distributed/__init__.py that imports and exposes init_distributed so
the original import works.
        
          
tilelang/distributed/utils.py (Outdated)
          
        
 import ctypes
-from tilelang.distributed.common.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+from tilelang.distributed import _create_ipc_handle, _sync_ipc_handles, _create_tensor
Avoid circular imports: import IPC helpers from the submodule, not the package.
Importing from tilelang.distributed here can create a cycle if __init__.py re-exports from utils. Import the extension directly via a relative path.
Apply:
-from tilelang.distributed import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+from .common.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
🤖 Prompt for AI Agents
In tilelang/distributed/utils.py around line 20, the file currently imports IPC
helpers from the package (from tilelang.distributed import _create_ipc_handle,
_sync_ipc_handles, _create_tensor) which can produce a circular import if
__init__.py re-exports utils; change this to a direct relative import from the
extension module (e.g. from ._ipc import _create_ipc_handle, _sync_ipc_handles,
_create_tensor) so utils imports the IPC helpers from the implementation module
rather than the package entry point.
def wait_eq(signal_tensor: torch.Tensor,
            signal: int,
            stream: Optional[torch.cuda.Stream] = None,
            require_i64=False):
    stream = stream or torch.cuda.current_stream()
    if signal_tensor.dtype == torch.int32:
        (err,) = cuda.cuStreamWaitValue32(
            stream.cuda_stream,
            signal_tensor.data_ptr(),
            signal,
            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
        )
        CUDA_CHECK(err)
    else:
        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
wait_eq: implement 64-bit path and remove/repurpose the unused require_i64 arg.
- The require_i64 argument is unused.
- Support int64/uint64 via cuStreamWaitValue64; error on other dtypes.
Apply:
-def wait_eq(signal_tensor: torch.Tensor,
-            signal: int,
-            stream: Optional[torch.cuda.Stream] = None,
-            require_i64=False):
-    stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
-        (err,) = cuda.cuStreamWaitValue32(
-            stream.cuda_stream,
-            signal_tensor.data_ptr(),
-            signal,
-            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
-        )
-        CUDA_CHECK(err)
-    else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+def wait_eq(
+    signal_tensor: torch.Tensor,
+    signal: int,
+    stream: Optional[torch.cuda.Stream] = None,
+    require_i64: bool = False,
+):
+    stream = stream or torch.cuda.current_stream()
+    dt = signal_tensor.dtype
+    if require_i64 and dt not in (torch.int64, torch.uint64):
+        raise TypeError(f"require_i64=True but got dtype {dt}")
+    if dt == torch.int32:
+        (err,) = cuda.cuStreamWaitValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWaitValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")This also addresses the static-analysis warning about the unused parameter. As per coding guidelines
🧰 Tools
🪛 Ruff (0.13.1)
287-287: Unused function argument: require_i64
(ARG001)
298-298: Create your own exception
(TRY002)
298-298: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/distributed/utils.py around lines 284 to 299, the wait_eq function
currently only handles 32-bit signals and leaves the require_i64 parameter
unused; implement support for 64-bit signals and remove or repurpose require_i64
to avoid the unused-parameter lint. Change the body so that if
signal_tensor.dtype is torch.int32 you call cuda.cuStreamWaitValue32(...), if it
is torch.int64 or torch.uint64 you call cuda.cuStreamWaitValue64(...) with
stream.cuda_stream and the tensor.data_ptr(), and invoke CUDA_CHECK on the
returned error in each case; for any other dtype raise an Exception mentioning
the unsupported dtype. Finally remove the unused require_i64 parameter from the
function signature (and update any call sites) or if you prefer repurpose it to
force 64-bit semantics, document and use it to select the 64-bit path.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/distributed/__init__.py (2)
3-3: Drop the redundant `noqa`.
`F401` isn't enabled in our Ruff config, so the `# noqa: F401` comment has no effect and just adds noise. Please remove it.
4-4: Re-export explicitly instead of using `import *`.
Pulling everything from `.utils` via `*` makes the package surface hard to reason about and can leak internal names. Please re-export the intended symbols explicitly (e.g., with a curated `__all__`) so downstream users know what's public and static analysis stays reliable.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tilelang/distributed/__init__.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.13.1)
tilelang/distributed/__init__.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: from .utils import * used; unable to detect undefined names
(F403)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tilelang/distributed/__init__.py(1 hunks)
- tilelang/distributed/common/__init__.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.13.1)
tilelang/distributed/common/__init__.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/distributed/__init__.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: from .utils import * used; unable to detect undefined names
(F403)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
        
          
tilelang/distributed/__init__.py (Outdated)
          
        
from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
from .utils import *  # noqa: F401 (no newline at end of file)
Clean up package exports so Ruff stops failing.
The # noqa: F401 here triggers RUF100, and the wildcard import needs an explicit F403 suppression (or an explicit __all__) to satisfy Ruff. Making the exports explicit resolves both issues.
Apply this diff:
-from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
-from .utils import *  # noqa: F401
+from .common import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+from .utils import *  # noqa: F403
+from .utils import __all__ as _utils_all
+
+__all__ = ["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor", *_utils_all]
+del _utils_all
As per static analysis hints
🧰 Tools
🪛 Ruff (0.13.1)
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: from .utils import * used; unable to detect undefined names
(F403)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In tilelang/distributed/__init__.py around lines 3-4, the wildcard import and
blanket noqa comments trigger Ruff RUF100/F403 errors; replace the star import
with explicit symbol imports from .utils (e.g., import the specific
functions/classes used) and define a module-level __all__ list containing the
exported names (including '_create_ipc_handle', '_sync_ipc_handles',
'_create_tensor' and the explicit names imported from .utils); remove the
wildcard import and broad noqa so static analysis passes.
@@ -0,0 +1,3 @@
+"""The common distributed modules"""
+
+from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
Fix Ruff RUF100 by making the re-export explicit.
Ruff flags this # noqa: F401 as an unused directive. Declaring __all__ makes the intent clear and removes the need for the directive, keeping the file lint-clean.
Apply this diff:
-from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor  # noqa: F401
+from .ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor
+
+__all__ = ["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor"]
As per static analysis hints
🧰 Tools
🪛 Ruff (0.13.1)
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In tilelang/distributed/common/__init__.py around line 3, the current re-export
uses a bare import with a trailing "# noqa: F401" which Ruff flags as an unused
directive; replace the noqa by making the re-export explicit: keep the from
.ipc_ext import _create_ipc_handle, _sync_ipc_handles, _create_tensor import,
remove the "# noqa: F401" comment, and add a module-level __all__ =
["_create_ipc_handle", "_sync_ipc_handles", "_create_tensor"] to explicitly
declare the exported symbols.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/distributed/utils.py (1)
268-281: Add support for 64-bit signals in set_signal.
The function currently only supports `torch.int32` signals. Consider adding support for `torch.int64` and `torch.uint64` using `cuStreamWriteValue64` for consistency with the `wait_eq` function.
 def set_signal(signal_tensor: torch.Tensor, signal: int, stream: Optional[torch.cuda.Stream] = None):
     stream = stream or torch.cuda.current_stream()
-    if signal_tensor.dtype == torch.int32:
+    dt = signal_tensor.dtype
+    if dt == torch.int32:
         (err,) = cuda.cuStreamWriteValue32(
             stream.cuda_stream,
             signal_tensor.data_ptr(),
             signal,
             cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
         )
         CUDA_CHECK(err)
+    elif dt in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
     else:
-        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+        raise TypeError(f"Unsupported signal dtype {dt}; expected int32/int64/uint64")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- CMakeLists.txt(0 hunks)
- setup.py(4 hunks)
- tilelang/distributed/__init__.py(1 hunks)
- tilelang/distributed/utils.py(3 hunks)
- tilelang/utils/allocator.py(4 hunks)
💤 Files with no reviewable changes (1)
- CMakeLists.txt
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/utils/allocator.py (1)
tilelang/utils/cpp/tensor_from_ptr.cpp (2)
tensor_from_ptr(46-94)
tensor_from_ptr(46-48)
🪛 Ruff (0.13.1)
setup.py
846-846: Unused method argument: verbose
(ARG002)
856-856: Avoid specifying long messages outside the exception class
(TRY003)
859-859: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
864-864: subprocess call: check for execution of untrusted input
(S603)
866-868: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/distributed/__init__.py
3-3: from .utils import * used; unable to detect undefined names
(F403)
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
4-4: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/distributed/utils.py
20-20: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
281-281: Create your own exception
(TRY002)
281-281: Avoid specifying long messages outside the exception class
(TRY003)
287-287: Unused function argument: require_i64
(ARG001)
298-298: Create your own exception
(TRY002)
298-298: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
🔇 Additional comments (5)
tilelang/utils/allocator.py (1)
202-210: Guard the peer-tensor path for unsupported modes
Two problems here:
- If `return_peers=True` on a non-distributed allocator, `self._group`/`self._buffer_ptrs` are still `None`, so the very first call raises `AttributeError`.
- Combining `return_peers=True` with `take_ownership=True` lets the local tensor free the underlying buffer while the peer tensors still reference it, leading to use-after-free.
We need hard guards before constructing `peer_ts`.
-        if return_peers:
+        if return_peers:
+            if (
+                not self._is_distributed
+                or self._group is None
+                or self._buffer_ptrs is None
+            ):
+                raise RuntimeError(
+                    "return_peers=True requires a distributed allocator with initialized peer mappings."
+                )
+            if take_ownership:
+                raise ValueError(
+                    "return_peers=True cannot be combined with take_ownership=True; the local tensor would free memory still referenced by peers."
+                )
             peer_ts = []
             for i in range(self._group.size()):
                 if i == self._local_rank:
                     peer_ts.append(t)
                 else:
tilelang/distributed/__init__.py (1)
3-4: Fix circular import and address static analysis issues.
The current imports create potential circular import issues and trigger Ruff warnings. The import from `ipc_ext` should be from the submodule to avoid cycles, and the wildcard import needs proper handling.
This issue was already flagged in previous review comments. Please implement the suggested fix to resolve the circular import and static analysis warnings.
tilelang/distributed/utils.py (3)
20-20: Fix circular import to avoid dependency cycles.
The import should be from the submodule rather than the package to prevent circular imports when `__init__.py` re-exports from utils.
This issue was already identified in previous review comments. Please change to a relative import from the submodule as suggested.
284-298: Implement 64-bit support and fix unused parameter.
The function has an unused `require_i64` parameter and lacks support for 64-bit signals as mentioned in previous review comments.
This issue was already flagged in previous review comments. Please implement the suggested fix to add 64-bit support and properly handle the `require_i64` parameter.
301-304: LGTM! Well-implemented CUDA stream priority utility.
The function correctly retrieves the maximum CUDA stream priority using the appropriate CUDA runtime API and includes proper error checking.
301-304: LGTM! Well-implemented CUDA stream priority utility.The function correctly retrieves the maximum CUDA stream priority using the appropriate CUDA runtime API and includes proper error checking.
Actionable comments posted: 0
🧹 Nitpick comments (2)
requirements-dev.txt (1)
26-27: Guard CUDA dependency for non‑CUDA hosts; confirm 12.x floor.To avoid breaking dev setups on macOS and other non‑CUDA environments, gate cuda‑python. Also verify that CUDA 12.x is the true minimum required by the new IPC paths.
Proposed diff:
-setuptools -cuda-python>=12.0.0 +setuptools +cuda-python>=12.0.0; platform_system != "Darwin"Optional:
- Move CUDA deps to an extra (e.g.,
tilescale[cuda]) or a separaterequirements-cuda.txtused only in distributed jobs.- If Windows also isn’t supported in your CI, add
and platform_system == "Linux"to be stricter.Please confirm whether features added in this PR require CUDA 12.x specifically, or if >=11.8 would suffice.
requirements-test.txt (1)
30-31: Relax Ninja pin; align CUDA guard with dev.Pinning
ninja==1.10.0is unnecessarily strict and dated; use a lower bound compatible with CMake and your builds. Also mirror the CUDA guard from dev to prevent failures on non‑CUDA hosts.Proposed diff:
-ninja==1.10.0 -cuda-python>=12.0.0 +ninja>=1.10 +cuda-python>=12.0.0; platform_system != "Darwin"Follow‑ups (optional):
- If only some tests need Ninja/CUDA, move them into conditional extras or per‑job requirements to reduce CI surface area.
- There are duplicate entries for
Cythonin this file (Lines 4 and 9); consider deduping as housekeeping.
This pull request introduces a new distributed all-gather GEMM example and adds key utilities for peer-to-peer signaling and tensor allocation in distributed settings. The changes improve support for distributed GPU computations, particularly for multi-process setups, and enhance the flexibility of tensor allocation and communication primitives.
New distributed example and utilities
example_allgather_gemm_ipc.pythat demonstrates distributed all-gather followed by GEMM computation using custom signaling and buffer sharing, including a correctness check against PyTorch's implementation.set_signalandwait_eqfunctions intilelang/distributed/utils.pyfor efficient peer-to-peer signaling between CUDA streams using atomic operations.Tensor allocation and API improvements
tilelang/utils/allocator.pyand thetensorAPI intilelang/utils/tensor.pyto support allocation of peer tensors across distributed ranks via a newreturn_peersargument. This enables returning a list of tensors, one per peer, for distributed memory access. [1] [2] [3].valuefor correct memory handling.Miscellaneous
example_all_to_all.pyby removing an unused import.Summary by CodeRabbit
New Features
Documentation
Tests
Chores