145 changes: 145 additions & 0 deletions examples/distributed/example_overlapping_allgather.py
@@ -0,0 +1,145 @@
import torch
import torch.distributed as dist
import pynvshmem
import tilelang
import tilelang.language as T
import os
from tilelang.distributed.utils import init_distributed
from tilelang.env import env
from packaging import version
import importlib.metadata

cuda_python_version = importlib.metadata.version("cuda-python")
if version.parse(cuda_python_version) >= version.parse("12.8.0"):
    from cuda.bindings import runtime as cudart
else:
    from cuda import cudart
# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py


def internode_gather(M, local_world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M), "float32"),
            src: T.Tensor((M), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
⚠️ Potential issue | 🟡 Minor

Hardcoded 2-node assumption limits flexibility.

The rank computation assumes exactly 2 nodes:

rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)

Consider either:

  1. Parameterizing the number of nodes
  2. Adding a clear assertion/comment that this example requires exactly 2 nodes
  3. Computing the number of nodes dynamically from environment variables

Example parameterization:

-def internode_gather(M, local_world_size, block_M, threads):
+def internode_gather(M, local_world_size, num_nodes, block_M, threads):
     @T.prim_func
     def main(
             dst: T.Tensor((M), "float32"),
             src: T.Tensor((M), "float32"),
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
-            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
+            rank[0] = (T.get_pe() + local_world_size) % (num_nodes * local_world_size)
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around line 33, the
rank calculation hardcodes a 2-node assumption using (2 * local_world_size),
which limits flexibility. Replace it with a configurable or dynamically
computed node count (e.g., derive num_nodes from the environment or from
total_world_size // local_world_size), or add an explicit assertion/comment
stating that the example requires exactly 2 nodes. Update the rank computation
to use num_nodes instead of 2 and validate inputs (raise/assert on mismatch) so
the example either works for arbitrary node counts or clearly documents its
2-node requirement.
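
As a concrete sketch combining options 1 and 3 above, the node count can be derived on the host from the standard torchrun environment variables. This is illustrative only; the num_nodes name and the assertion message are not part of this PR:

import os

# Derive the node count instead of hardcoding 2.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
assert world_size % local_world_size == 0, "WORLD_SIZE must be a multiple of LOCAL_WORLD_SIZE"
num_nodes = world_size // local_world_size

# The kernel's 2-node constant would then become:
# rank[0] = (T.get_pe() + local_world_size) % (num_nodes * local_world_size)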

            T.putmem_nbi_block(
                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4,
                rank[0])

    return main


def intranode_gather(M, world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M * world_size), "float32"),
            src: T.Tensor((M * 2), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            num_rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()
            num_rank[0] = T.get_num_ranks()
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="signal")
                T.print(T.cast(num_rank[0], "int32"), msg="signal")
            for k in T.serial(world_size // 2):  # 2 node
⚠️ Potential issue | 🟡 Minor

Hardcoded 2-node assumption in loop.

Similar to the internode kernel, this loop assumes exactly 2 nodes:

for k in T.serial(world_size // 2):  # 2 node

Consider parameterizing or documenting this constraint clearly. If this example is specifically designed for 2 nodes, add an early assertion:

if __name__ == '__main__':
    num_nodes = int(os.environ.get('NODES', 2))
    assert num_nodes == 2, "This example requires exactly 2 nodes"
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around line 57, the
loop uses a hardcoded 2-node assumption ("for k in T.serial(world_size // 2):  #
2 node"). Either derive the loop bound from a configurable
num_nodes/worker_count variable, or add an early runtime assertion that the
example requires exactly 2 nodes (read from an env var or CLI) and fail fast
with a clear message. Update the top-of-script argument/env parsing, or add the
assert in the main entry, so the loop remains correct for the intended number of
nodes.

                T.put_block(
                    src=T.address_of(src[bx * block_M]),
                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )
                T.put_block(
                    src=T.address_of(src[bx * block_M + M]),
                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )

    return main


if __name__ == '__main__':
    tilelang.disable_cache()

    M = 2
    K = 12288
    # For 2 nodes (16 GPUs): world_size=16, rank is 0-15, local rank is 0-7
    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
        return_tp_group=True, return_lc_group=True)
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

    allocator = tilelang.get_allocator(
        size=2**25,
        device="cuda",
        is_distributed=True,
        local_rank=LOCAL_RANK,
        num_local_ranks=local_world_size,
        group=LC_GROUP)
    print(local_world_size, LOCAL_RANK)

    # Each rank sends its local_tensor to the ranks of other nodes with the same local_rank.
    # Assuming there are 2 nodes, each with 4 workers:
    # 0-th local tensor ([0] -> [4]), 4-th local tensor ([4] -> [0])
    # 1-th local tensor ([1] -> [5]), 5-th local tensor ([5] -> [1])
    # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
    # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
    interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
    if LOCAL_RANK == 0:
        print(interkernel.get_kernel_source())
    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
    src.copy_(input_data)

    pynvshmem.nvshmem_barrier_all()
    dist.barrier(TP_GROUP)
    interkernel(dst, src)
    pynvshmem.nvshmem_barrier_all()

    # Each rank sends its local_tensor and the received internode tensor to the intranode ranks.
    # 0-th and 4-th local tensors ([0] -> [1, 2, 3])
    # 1-th and 5-th local tensors ([1] -> [0, 2, 3])
    # 2-th and 6-th local tensors ([2] -> [0, 1, 3])
    # 3-th and 7-th local tensors ([3] -> [0, 1, 2])
    # 0-th and 4-th local tensors ([4] -> [5, 6, 7])
    # 1-th and 5-th local tensors ([5] -> [4, 6, 7])
    # 2-th and 6-th local tensors ([6] -> [4, 5, 7])
    # 3-th and 7-th local tensors ([7] -> [4, 5, 6])
    src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_()
    dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator)
    if RANK < WORLD_SIZE / 2:
        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
    else:
        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
Comment on lines +122 to +130
⚠️ Potential issue | 🟠 Major

Add error checking for CUDA memory operations.

The cudaMemcpy calls don't check return codes. CUDA operations can fail silently, leading to incorrect results.

The codebase uses CUDA_CHECK for error handling (see tilelang/distributed/utils.py lines 249-257). Apply error checking:

+    from tilelang.distributed.utils import CUDA_CHECK
+
     if RANK < WORLD_SIZE / 2:
-        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
     else:
-        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
-                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
+        CUDA_CHECK(cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
+                                      cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
🤖 Prompt for AI Agents
In examples/distributed/example_overlapping_allgather.py around lines 125 to
133, the cudart.cudaMemcpy calls are missing error checking; wrap each
cudaMemcpy invocation with the project's CUDA_CHECK helper (the same pattern
used in tilelang/distributed/utils.py lines ~249-257) so that the return code is
validated and any CUDA errors are surfaced. Replace each raw
cudart.cudaMemcpy(...) call with a CUDA_CHECK invocation that passes the
cudaMemcpy call and preserves the same arguments and intent.


    env.USE_NVSHMEM = False
    intrakernel = tilelang.compile(
        intranode_gather(M, WORLD_SIZE, M, 128),
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
    intrakernel.initialize(allocator=allocator)
    if LOCAL_RANK == 0:
        print(intrakernel.get_kernel_source())
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)
    intrakernel(dst_intra, src_intra)
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)

    print(dst_intra)
4 changes: 2 additions & 2 deletions src/target/codegen_cuda.cc
@@ -181,8 +181,8 @@ std::string CodeGenTileLangCUDA::Finish() {
   }
 
   if (use_nvshmem_) {
-    decl_stream << "#include <nvshmem.h>>\n";
-    decl_stream << "#include <nvshmemx.h>>\n";
+    decl_stream << "#include <nvshmem.h>\n";
+    decl_stream << "#include <nvshmemx.h>\n";
   }
 
   if (need_cooperative_groups_) {
8 changes: 6 additions & 2 deletions tilelang/distributed/launch.sh
@@ -20,8 +20,12 @@ nproc_per_node=${GPUS:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `GPUS`
 nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
 node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node
 
-master_addr="127.0.0.1"
-master_port="$(( RANDOM % 100 + 23400 ))" # randomly choose a port between 23400 and 23499
+master_addr=${ARNOLD_WORKER_0_HOST:="127.0.0.1"}
+if [ -z ${ARNOLD_WORKER_0_PORT} ]; then
+    master_port="8361"
+else
+    master_port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
+fi
 additional_args="--rdzv_endpoint=${master_addr}:${master_port}"
 IB_HCA=mlx5
 
12 changes: 10 additions & 2 deletions tilelang/distributed/utils.py
@@ -56,13 +56,14 @@ def init_dist(local_rank: int, num_local_ranks: int):
         list(range(num_local_ranks * num_nodes)))
 
 
-def init_distributed(return_tp_group=False, init_nvshmem=True):
+def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=False):
     WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
     RANK = int(os.environ.get("RANK", 0))
     LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
 
     torch.distributed.init_process_group(
         backend="nccl",
+        device_id=torch.device(f'cuda:{LOCAL_RANK}'),
         world_size=WORLD_SIZE,
         rank=RANK,
         timeout=datetime.timedelta(seconds=1800),
@@ -76,7 +77,14 @@ def init_distributed(return_tp_group=False, init_nvshmem=True):
         import pynvshmem
         pynvshmem.init_nvshmem_by_uniqueid(TP_GROUP)
 
-    if return_tp_group:
+    if return_lc_group:
+        local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+        base = (RANK // local_world_size) * local_world_size
+        LC_GROUP = torch.distributed.new_group(
+            list(range(base, base + local_world_size)), backend="nccl")
+
+        return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
+    elif return_tp_group:
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP
     else:
         return WORLD_SIZE, RANK, LOCAL_RANK
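For context, a minimal call site exercising the new return_lc_group flag (a sketch mirroring the example script above, not additional code in this PR):

from tilelang.distributed.utils import init_distributed

# With return_lc_group=True the function returns a fifth value, the node-local group.
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
    return_tp_group=True, return_lc_group=True)
# LC_GROUP spans only the ranks co-located on this node
# (e.g. ranks 0-7 on node 0 of a 2-node x 8-GPU job); TP_GROUP is unchanged.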
2 changes: 1 addition & 1 deletion tilelang/utils/allocator.py
@@ -149,7 +149,7 @@ def _init_table(self):
         ] * self._group.size()
         local_ipc_handle = _create_ipc_handle(self._base_ptr.value)
         dist.all_gather_object(ipc_handles, local_ipc_handle, self._group)
-        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64)
+        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64, device='cuda')
         _sync_ipc_handles(self._local_rank, device_ids,
                           ctypes.c_void_p(buffer_ptrs.data_ptr()).value, ipc_handles, None)
         buffer_ptrs[self._local_rank] = self._base_ptr.value