43 changes: 40 additions & 3 deletions .github/workflows/ci.yml
@@ -97,7 +97,6 @@ jobs:
PIP_NO_BUILD_ISOLATION=1 pip install -r requirements-test.txt --no-user
# flash attention usually requires no isolation build
pip install flash_attn==2.5.8 --no-user --no-build-isolation
pip install . --no-user
touch "$MARKER"
fi

@@ -111,11 +110,49 @@ jobs:
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd examples
unset PYTHONPATH
python -m pytest -n 4 **/test*.py -v -r fE

# find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
printf '%s\n' "${DIST_TESTS[@]}"
TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
else
echo "No distributed examples found."
fi
Comment on lines +114 to +122
⚠️ Potential issue | 🟠 Major

Run distributed example tests serially and harden shell options

Parallelizing distributed GPU tests with pytest-xdist (-n 4) commonly oversubscribes GPUs/process groups and causes flaky hangs/OOMs. Also add strict shell flags to fail fast on CI.

Apply this diff:

-        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed examples found."
         fi
🤖 Prompt for AI Agents
In .github/workflows/ci.yml around lines 115 to 123, the workflow currently runs
distributed GPU tests in parallel and lacks strict shell failure flags; change
it to run the distributed tests serially (remove pytest-xdist -n 4) so
TILELANG_USE_DISTRIBUTED=1 python -m pytest "${DIST_TESTS[@]}" -v -r fE is used,
and add strict shell options at the start of the shell block (e.g., set -euo
pipefail and IFS=$'\n\t') so the job fails fast on errors and handles unset
variables/pipe failures robustly.


# run remaining example tests (non-distributed)
mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' 2>/dev/null || true)
if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then
echo "Running non-distributed examples:"
printf '%s\n' "${OTHER_TESTS[@]}"
python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE
else
echo "No non-distributed example tests found."
fi

- name: Run tests
run: |
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd testing/python
unset PYTHONPATH
python -m pytest -n 4 -v -r fE

# run distributed tests first with env var
mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
printf '%s\n' "${DIST_TESTS[@]}"
TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
else
echo "No distributed tests found under testing/python."
fi
Comment on lines +140 to +148
⚠️ Potential issue | 🟠 Major

Run distributed tests (testing/python) serially and add strict shell flags

Same flakiness risk here from -n 4 on GPU-distributed tests; also add set -euo pipefail for safety.

Apply this diff:

-        # run distributed tests first with env var
+        # run distributed tests first with env var
+        set -euo pipefail
         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
           echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
           printf '%s\n' "${DIST_TESTS[@]}"
-          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" -v -r fE
+          # run distributed tests serially to avoid GPU/process-group contention
+          TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 1 "${DIST_TESTS[@]}" -v -r fE
         else
           echo "No distributed tests found under testing/python."
         fi
🤖 Prompt for AI Agents
.github/workflows/ci.yml lines 141-149: the workflow runs GPU-distributed tests
in parallel with pytest -n 4 which causes flakiness and lacks strict shell
safety; remove the -n 4 flag so pytest runs serially for those distributed
tests, and add strict shell flags (set -euo pipefail) before this block (or at
top of the script/job) so failures, undefined variables, and pipe errors abort
the job; ensure TILELANG_USE_DISTRIBUTED=1 is preserved when invoking pytest.


# run remaining tests
mapfile -t OTHER_TESTS < <(find . -type f -name 'test*.py' ! -path '*/distributed/*' 2>/dev/null || true)
if [ "${#OTHER_TESTS[@]}" -gt 0 ]; then
echo "Running non-distributed tests:"
printf '%s\n' "${OTHER_TESTS[@]}"
python -m pytest -n 4 "${OTHER_TESTS[@]}" -v -r fE
else
echo "No non-distributed tests found under testing/python."
fi
24 changes: 0 additions & 24 deletions CMakeLists.txt
@@ -222,30 +222,6 @@ if(CMAKE_BUILD_TYPE STREQUAL "Debug")
target_compile_definitions(tilelang_static PRIVATE "TVM_LOG_DEBUG")
endif()

set(ALLOC_CUDA_EXT_DIR ${PROJECT_SOURCE_DIR}/tilelang/utils/cpp)
add_custom_target(
build_alloc_cuda_ext ALL
COMMAND
${Python_EXECUTABLE} setup.py build_ext --inplace
WORKING_DIRECTORY
${ALLOC_CUDA_EXT_DIR}
COMMENT
"Building alloc_cuda PyTorch extension (in-place)"
)
add_dependencies(tilelang build_alloc_cuda_ext)

set(IPC_EXT_DIR ${PROJECT_SOURCE_DIR}/tilelang/distributed/common)
add_custom_target(
build_ipc_ext ALL
COMMAND
${Python_EXECUTABLE} setup.py build_ext --inplace
WORKING_DIRECTORY
${IPC_EXT_DIR}
COMMENT
"Building ipc_ext PyTorch extension (in-place)"
)
add_dependencies(tilelang build_ipc_ext)

# Building tvm_cython modules
if(NOT DEFINED TVM_PREBUILD_PATH)
add_dependencies(tilelang tvm_cython)
2 changes: 1 addition & 1 deletion benchmark/distributed/benchmark_ag_gemm.py
@@ -18,7 +18,7 @@
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.distributed import init_distributed, dtype_map, perf_fn
from triton_dist.kernels.nvidia.allgather_gemm import ag_gemm, create_ag_gemm_context
from functools import partial

2 changes: 1 addition & 1 deletion benchmark/distributed/benchmark_all_gather.py
@@ -4,7 +4,7 @@
import pynvshmem
import tilelang
import tilelang.language as T
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.distributed import init_distributed, dtype_map, perf_fn
from typing import List

tilelang.disable_cache()
2 changes: 1 addition & 1 deletion benchmark/distributed/benchmark_all_to_all.py
@@ -1,7 +1,7 @@
import torch
import tilelang
import tilelang.language as T
from tilelang.distributed.utils import init_distributed, dtype_map
from tilelang.distributed import init_distributed, dtype_map
import argparse
import random
from triton_dist.kernels.nvidia import fast_all_to_all, all_to_all_post_process
2 changes: 1 addition & 1 deletion benchmark/distributed/benchmark_gemm_rs.py
@@ -9,7 +9,7 @@
import tilelang
import tilelang.language as T
# from tilelang.carver.arch import driver
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.distributed import init_distributed, dtype_map, perf_fn

tilelang.disable_cache()

2 changes: 1 addition & 1 deletion benchmark/distributed/benchmark_reduce_scatter.py
@@ -7,7 +7,7 @@
import pynvshmem
import tilelang
import tilelang.language as T
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.distributed import init_distributed, dtype_map, perf_fn

tilelang.disable_cache()

2 changes: 1 addition & 1 deletion benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
@@ -9,7 +9,7 @@
import argparse
import torch
import torch.distributed as dist
from tilelang.distributed.utils import init_distributed, perf_fn
from tilelang.distributed import init_distributed, perf_fn
import pynvshmem

os.environ['NCCL_DEBUG'] = 'WARN'
@@ -5,7 +5,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing
from tilelang.distributed.utils import init_dist, perf_fn
from tilelang.distributed import init_dist, perf_fn

tilelang.disable_cache()
os.environ['NCCL_DEBUG'] = 'WARN'
4 changes: 1 addition & 3 deletions examples/distributed/example_all_to_all.py
@@ -1,14 +1,12 @@
import torch
import pynvshmem
import os
import tilelang
import tilelang.language as T
from tilelang.profiler import TensorSupplyType
from tilelang.distributed.utils import init_distributed
from tilelang.distributed import init_distributed
import argparse
import random


tilelang.disable_cache()


2 changes: 1 addition & 1 deletion examples/distributed/example_allgather.py
@@ -4,7 +4,7 @@
import pynvshmem
import tilelang
import tilelang.language as T
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.distributed import init_distributed, dtype_map, perf_fn


def allgather(PE_num, M, N, dtype="float16", threads=128):
2 changes: 1 addition & 1 deletion examples/distributed/example_allgather_gemm.py
@@ -4,7 +4,7 @@
import tilelang
import tilelang.language as T
from tilelang.profiler import TensorSupplyType
from tilelang.distributed.utils import init_distributed
from tilelang.distributed import init_distributed


def allgather_gemm(PE_num, M, N, K, block_M, block_N, block_K, dtype="float16"):
198 changes: 198 additions & 0 deletions examples/distributed/example_allgather_gemm_ipc.py
@@ -0,0 +1,198 @@
import os
import tilelang
import tilelang.language as T
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing
from tilelang.distributed import init_dist
from cuda import cudart
from tilelang.distributed import set_signal, wait_eq

tilelang.disable_cache()
os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log


def gemm_kernel(M,
                N,
                K,
                num_rank,
                block_M,
                block_N,
                block_K,
                threads,
                dtype="float16",
                accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N // num_rank), dtype),
            C: T.Tensor((M, N // num_rank), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

Comment on lines +33 to +44
⚠️ Potential issue | 🔴 Critical

Fix grid size: using N instead of N_per_rank causes out-of-bounds on B/C

Kernel grid must reflect local N dimension per rank. Using global N overshoots.

Apply this diff:

-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+        with T.Kernel(T.ceildiv(N // num_rank, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 33 to 44 the
kernel grid uses the global N which causes out-of-bounds accesses for B/C on
each rank; change the x-dimension grid calculation to use the per-rank local
width (replace T.ceildiv(N, block_N) with T.ceildiv(N_per_rank, block_N)) and
ensure any indexing/limits that assume the x-dimension (reads from B and writes
to C) are bounded by N_per_rank so the kernel only addresses the local slice.
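
For concreteness, a short sketch with illustrative sizes (8192 columns, 8 ranks, block_N = 128 are made-up numbers, not taken from this PR) showing how the two grids differ:

```python
# Illustrative sizes only: B and C hold just N // num_rank columns per rank,
# so the grid's x-dimension must come from that local width, not the global N.
N, num_rank, block_N = 8192, 8, 128
local_N = N // num_rank                 # 1024 columns actually allocated for B/C
grid_x_global = -(-N // block_N)        # ceildiv(N, block_N)       -> 64 blocks
grid_x_local = -(-local_N // block_N)   # ceildiv(local_N, block_N) -> 8 blocks
# With the global grid, blocks bx >= 8 would read B[:, bx * block_N:] and write
# C[:, bx * block_N:] beyond the local slice; the local grid stays in bounds.
assert grid_x_local * block_N == local_N
```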

    return main


def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
                                      signal_target, rank, local_world_size, world_size,
                                      intranode_ag_stream):
    local_rank = rank % local_world_size
Comment on lines +48 to +51
⚠️ Potential issue | 🟠 Major

Parameter naming bug: N here is actually K (second dim of A)

Misnaming invites misuse. Rename to K and propagate to all uses.

Apply this diff:

-def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, K,
                                       signal_target, rank, local_world_size, world_size,
                                       intranode_ag_stream):

Follow-up diffs below update its uses.

🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 48 to 51, the
function parameter named N is incorrect — it represents K (the second dimension
of A); rename the parameter to K in the function signature and update every call
site and internal reference to use K instead of N, ensuring variable names and
any related comments/docstrings reflect the new name so behavior remains
unchanged.

    n_nodes = world_size // local_world_size
    node_rank = rank // local_world_size

    for i in range(1, local_world_size):
        segment = rank * M_per_rank * N
        local_dst_rank = (local_rank + local_world_size - i) % local_world_size
        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
        # Using copy engine to perform intranode transmission
        # Sending rank-th local tensor to other ranks inside the node.
        (err,) = cudart.cudaMemcpyAsync(
            dst_ptr,
            src_ptr,
            M_per_rank * N * local_tensor.element_size(),
            cudart.cudaMemcpyKind.cudaMemcpyDefault,
            intranode_ag_stream.cuda_stream,
        )
        # Notify the peer that the transmission is done.
Comment on lines +55 to +69
🛠️ Refactor suggestion | 🟠 Major

Use K for segment sizing and add CUDA error checks

Copy size/offset must be based on the width of A (K). Also check cudaMemcpyAsync return codes.

Apply this diff:

-    for i in range(1, local_world_size):
-        segment = rank * M_per_rank * N
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * K
         local_dst_rank = (local_rank + local_world_size - i) % local_world_size
-        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
-        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
+        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
+        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
         # Using copy engine to perform intranode transmission
         # Sending rank-th local tensor to other ranks inside the node.
         (err,) = cudart.cudaMemcpyAsync(
             dst_ptr,
             src_ptr,
-            M_per_rank * N * local_tensor.element_size(),
+            M_per_rank * K * local_tensor.element_size(),
             cudart.cudaMemcpyKind.cudaMemcpyDefault,
             intranode_ag_stream.cuda_stream,
         )
+        CUDA_CHECK(err)
         # Notify the peer that the transmission is done.
         set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)
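
Note that the diff above calls a CUDA_CHECK helper the example file does not define. A minimal sketch of what such a helper could look like with the cuda-python bindings already imported by this example (the name and exact behaviour are assumptions, not part of the PR):

```python
from cuda import cudart


def CUDA_CHECK(err):
    """Raise if a cudart call returned a non-success status (sketch, not from the PR)."""
    if err != cudart.cudaError_t.cudaSuccess:
        # cudaGetErrorString itself returns a (status, message) tuple.
        _, msg = cudart.cudaGetErrorString(err)
        text = msg.decode() if isinstance(msg, (bytes, bytearray)) else str(msg)
        raise RuntimeError(f"CUDA error {int(err)}: {text}")
```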

        set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)

    for i in range(1, n_nodes):
        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
        recv_segment = recv_rank * M_per_rank * N
        # Waiting for the internode data ready
        wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
        src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
        for j in range(1, local_world_size):
            local_dst_rank = (local_rank + local_world_size - j) % local_world_size
            dst_ptr = ag_buffer[local_dst_rank].data_ptr(
            ) + recv_segment * local_tensor.element_size()
            # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
            (err,) = cudart.cudaMemcpyAsync(
                dst_ptr,
                src_ptr,
                M_per_rank * N * local_tensor.element_size(),
                cudart.cudaMemcpyKind.cudaMemcpyDefault,
                intranode_ag_stream.cuda_stream,
            )
            # Notify the peer that the transmission is done.
            set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)

Comment on lines +72 to +92
🛠️ Refactor suggestion | 🟠 Major

Likewise for inter-node segment sizing and error checks

Mirror the K-based sizing and add error handling.

Apply this diff:

-    for i in range(1, n_nodes):
-        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
-        recv_segment = recv_rank * M_per_rank * N
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * K
         # Waiting for the internode data ready
         wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
         src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
         for j in range(1, local_world_size):
             local_dst_rank = (local_rank + local_world_size - j) % local_world_size
             dst_ptr = ag_buffer[local_dst_rank].data_ptr(
             ) + recv_segment * local_tensor.element_size()
             # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
             (err,) = cudart.cudaMemcpyAsync(
                 dst_ptr,
                 src_ptr,
-                M_per_rank * N * local_tensor.element_size(),
+                M_per_rank * K * local_tensor.element_size(),
                 cudart.cudaMemcpyKind.cudaMemcpyDefault,
                 intranode_ag_stream.cuda_stream,
             )
+            CUDA_CHECK(err)
             # Notify the peer that the transmission is done.
             set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
🧰 Tools
🪛 Ruff (0.13.1)

83-83: Unpacked variable err is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 72 to 92, the
inter-node segment sizing currently uses N but should mirror the K-based sizing
used on the intranode path, and it lacks error checks on cudaMemcpyAsync;
change recv_segment and the byte count passed to cudaMemcpyAsync to use
M_per_rank * K (times element_size), check the returned err from
cudart.cudaMemcpyAsync and raise or log a RuntimeError if err is non-zero, and
only call set_signal after the copy completed successfully (or after handling
the error).


def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
               local_world_size, world_size, gemm_kernel, ag_stream):

    dist.barrier(group)

    # all_gather A to ag_buffer
    with torch.cuda.stream(ag_stream):
        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
⚠️ Potential issue | 🟠 Major

Pass correct K dimension to cp_engine_producer_all_gather_put.

The function call passes N but should pass K to match the corrected function signature.

Apply this diff:

-        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
                                           rank, local_world_size, world_size, ag_stream)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around line 101, the call
to cp_engine_producer_all_gather_put incorrectly passes N as the K dimension;
update the call to pass K instead of N so the arguments match the corrected
function signature (replace the N argument with K in that function call).

                                          rank, local_world_size, world_size, ag_stream)

    current_stream = torch.cuda.current_stream()
    current_stream.wait_stream(ag_stream)

    dist.barrier(group)
    torch.cuda.synchronize()

    torch.cuda.synchronize()
    torch.distributed.barrier(group)
    gemm_kernel(ag_buffer[rank], B, C)
    torch.cuda.synchronize()
    torch.distributed.barrier(group)
Comment on lines +110 to +114
⚠️ Potential issue | 🔴 Critical

Use local peer for kernel input; drop redundant sync

Fix out-of-range access when rank >= local_world_size. Also remove duplicate synchronize.

Apply this diff:

-    torch.cuda.synchronize()
-    torch.distributed.barrier(group)
-    gemm_kernel(ag_buffer[rank], B, C)
+    gemm_kernel(ag_buffer[local_rank], B, C)
     torch.cuda.synchronize()
-    torch.distributed.barrier(group)
+    torch.distributed.barrier(group)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 110 to 114,
the code uses the global rank as the index into ag_buffer which can be
out-of-range when rank >= local_world_size; change the kernel input index to a
local peer index (e.g., local_peer = rank % local_world_size or however
local_world_size is computed in this module) and call
gemm_kernel(ag_buffer[local_peer], B, C) instead of ag_buffer[rank]; also remove
the redundant torch.cuda.synchronize() (keep a single synchronize after the
kernel) so you only call torch.distributed.barrier(group), gemm_kernel(... with
local_peer ...), torch.cuda.synchronize(), and torch.distributed.barrier(group).
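
A quick sketch of why the global rank can overflow the peer list, assuming (as the comment implies) one ag_buffer entry per local peer; the 2-node x 4-local-rank layout is made up for illustration:

```python
# Assumed layout: each process holds one ag_buffer entry per *local* peer,
# so only the local index is guaranteed to stay in range.
local_world_size, world_size = 4, 8
num_peer_entries = local_world_size           # assumed len(ag_buffer) per process
for rank in range(world_size):
    local_rank = rank % local_world_size
    assert local_rank < num_peer_entries      # local index is always valid
assert world_size - 1 >= num_peer_entries     # e.g. global rank 7 would index out of range
```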


    return C


def torch_ag_gemm(
    pg: torch.distributed.ProcessGroup,
    local_input: torch.Tensor,
    local_weight: torch.Tensor,
    ag_out: torch.Tensor,
):
    torch.distributed.all_gather_into_tensor(ag_out, local_input, pg)
    ag_gemm_output = torch.matmul(ag_out, local_weight)
    return ag_gemm_output


def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    dtype = torch.float16
    M = args.M if args else 8192
    N = args.N if args else 8192
    K = args.K if args else 8192
    M_per_rank = M // num_local_ranks
    N_per_rank = N // num_local_ranks

    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 64
    threads = 256

    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    allocator = tilelang.get_allocator(
        size=2**30,
        device="cuda",
        is_distributed=True,
        local_rank=local_rank,
        num_local_ranks=num_local_ranks,
        group=group)
    kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads))
Comment on lines +130 to +151
⚠️ Potential issue | 🔴 Critical

Compute per-rank sizes using the global world size and after init_dist

Current M_per_rank/N_per_rank use num_local_ranks, which breaks multi-node (OOB segment indexing).

Apply this diff:

-    M = args.M if args else 8192
-    N = args.N if args else 8192
-    K = args.K if args else 8192
-    M_per_rank = M // num_local_ranks
-    N_per_rank = N // num_local_ranks
+    M = args.M if args else 8192
+    N = args.N if args else 8192
+    K = args.K if args else 8192
@@
-    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    M_per_rank = M // num_ranks
+    N_per_rank = N // num_ranks
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 130-151,
M_per_rank and N_per_rank are computed too early using num_local_ranks which
breaks multi-node indexing; move the M_per_rank/N_per_rank calculations to after
the init_dist(...) call and compute them with the global world size (use
num_ranks returned by init_dist), e.g. M_per_rank = M // num_ranks and
N_per_rank = N // num_ranks, ensuring the per-rank sizes are based on the global
number of ranks and evaluated after init_dist.
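
For intuition, a tiny sketch with made-up sizes (2 nodes x 4 local ranks) showing why the divisor must be the global world size returned by init_dist:

```python
# Made-up sizes: with 2 nodes of 4 local ranks each, the global world size is 8.
M = 8192
num_local_ranks, n_nodes = 4, 2
num_ranks = num_local_ranks * n_nodes   # what init_dist would report in this setup
wrong = M // num_local_ranks            # 2048: rank 7's segment would start at 14336 > M
right = M // num_ranks                  # 1024: rank 7's segment ends exactly at M
assert (num_ranks - 1) * wrong > M      # global-rank offsets overflow the buffer
assert num_ranks * right == M           # offsets tile the buffer exactly
```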

    kernel.initialize(allocator=allocator)
    if local_rank == 0:
        print(kernel.get_kernel_source())

    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
    signal_buffer = tilelang.tensor((num_local_ranks,),
                                    torch.int32,
                                    allocator=allocator,
                                    return_peers=True)
    signal_buffer[rank].fill_(0)
    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)

Comment on lines +156 to +166
⚠️ Potential issue | 🔴 Critical

Fix peer buffer/signal shapes and local indexing

  • signal_buffer inner length must be world_size (you index by global rank).
  • Index ag_buffer/signal_buffer by local_rank, not global rank.

Apply this diff:

-    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
-    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
-    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
-    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
-    signal_buffer = tilelang.tensor((num_local_ranks,),
-                                    torch.int32,
-                                    allocator=allocator,
-                                    return_peers=True)
-    signal_buffer[rank].fill_(0)
-    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
+    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
+    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
+    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
+    signal_buffer = tilelang.tensor((num_ranks,),
+                                    torch.int32,
+                                    allocator=allocator,
+                                    return_peers=True)
+    signal_buffer[local_rank].fill_(0)
+    ag_buffer[local_rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 156 to 166,
the peer signal buffer is created with length num_local_ranks and both ag_buffer
and signal_buffer are indexed by global rank; change the signal_buffer inner
length to the global world size (its slots are indexed by global rank across all
peers) and index the peer lists by local_rank when writing local data/signals
(e.g., signal_buffer[local_rank].fill_(0) and ag_buffer[local_rank][...]), while
keeping the row slice at the global-rank offset rank * M_per_rank ..
(rank + 1) * M_per_rank so the local A lands in its own segment of the gathered
buffer. Ensure the buffer constructions still use allocator and return_peers as
before, with only the signal_buffer shape and the peer-list indexing corrected.

    dist.barrier(group)

    ag_stream = torch.cuda.Stream()
    signal_target = 1

    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)

Comment on lines +172 to +174
⚠️ Potential issue | 🟠 Major

Pass the correct world size (global) to ag_gemm_op

You currently pass num_local_ranks for both local and global sizes. Use the true num_ranks for world_size.

Apply this diff:

-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
-                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_ranks, kernel, ag_stream)
🤖 Prompt for AI Agents
In examples/distributed/example_allgather_gemm_ipc.py around lines 173 to 175,
the call to ag_gemm_op passes num_local_ranks twice (for both local and
global/world sizes); change the second num_local_ranks argument to the true
global world size variable num_ranks so the function receives local_size,
world_size correctly—update the ag_gemm_op call to pass num_ranks as the
world_size parameter.

    torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
    torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer)

    if torch.allclose(torch_C, tilelang_C, atol=1e-6, rtol=1e-6):
        print(f"rank {local_rank} check passed.✅")
    else:
        print(f"rank {local_rank} check failed.❌")
        print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}")
        raise ValueError("Test failed")

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)')
    parser.add_argument('--M', type=int, default=8192, help='M dimension')
    parser.add_argument('--N', type=int, default=8192, help='N dimension')
    parser.add_argument('--K', type=int, default=8192, help='K dimension')
    args = parser.parse_args()
    num_processes = args.num_processes

    torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes)
2 changes: 1 addition & 1 deletion examples/distributed/example_cannon.py
@@ -3,7 +3,7 @@
import pynvshmem
import tilelang
import tilelang.language as T
from tilelang.distributed.utils import init_distributed, dtype_map
from tilelang.distributed import init_distributed, dtype_map
import math
import argparse
