192 changes: 192 additions & 0 deletions examples/distributed/primitives/example_tilescale_copy.py
@@ -0,0 +1,192 @@
import os
import tilelang
import tilelang.language as T
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing
from tilelang.distributed import init_dist

tilelang.disable_cache()
os.environ['NCCL_DEBUG'] = 'WARN' # reduce NCCL logging to warnings only


@tilelang.jit
def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'):
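"""Build one of five two-rank copy kernels: SIMT push of the full buffer,
SIMT push per tile, SIMT pull per tile, TMA load from the peer, or TMA store
to the peer. `kernel` selects which one is returned."""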

@T.prim_func
def simt_push_buffer(
dst: T.Tensor((M, N), "float32"),
src: T.Tensor((M, N), "float32"),
):
with T.Kernel((1), threads=threads):
rank = T.alloc_local([1], "uint64")
rank[0] = T.get_rank()

T.copy(
src,
dst,
dst_pe=1 - rank[0],
disable_tma=True # Ensure testing SIMT remote copy
)

@T.prim_func
def simt_push_tile(
dst: T.Tensor((M, N), "float32"),
src: T.Tensor((M, N), "float32"),
):
with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):
rank = T.alloc_local([1], "uint64")
rank[0] = T.get_rank()

smem = T.alloc_shared((block_M, block_N), "float32")
T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

T.copy(
src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
smem,
disable_tma=True # Local copy into shared memory; keep it on the SIMT path as well
)

T.copy(
smem,
dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
dst_pe=1 - rank[0],
disable_tma=True # Ensure testing SIMT remote copy
)

@T.prim_func
def simt_pull_tile(
dst: T.Tensor((M, N), "float32"),
src: T.Tensor((M, N), "float32"),
):
with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):
rank = T.alloc_local([1], "uint64")
rank[0] = T.get_rank()

smem = T.alloc_shared((block_M, block_N), "float32")
T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

T.copy(
src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
smem,
src_pe=1 - rank[0],
disable_tma=True # Ensure testing SIMT remote copy
)

T.copy(
smem,
dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
disable_tma=True # Local copy back to global memory; keep it on the SIMT path as well
)

# TMA kernels need the peer rank as a run-time expression (see the NOTE below)
@T.prim_func
def tma_load_tile(
dst: T.Tensor((M, N), "float32"),
src: T.Tensor((M, N), "float32"),
):
with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):

smem = T.alloc_shared((block_M, block_N), "float32")
T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

# TMA load
T.copy(
src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
smem,
src_pe=1 - T.get_rank(),
# NOTE(wt): We cannot use rank[0] as above for TMA remote copy currently.
)

T.copy(
smem,
dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
disable_tma=True # Local SIMT copy, so only the load above exercises TMA
)

@T.prim_func
def tma_store_tile(
dst: T.Tensor((M, N), "float32"),
src: T.Tensor((M, N), "float32"),
):
with T.Kernel(M // block_M, N // block_N, threads=threads) as (bx, by):

smem = T.alloc_shared((block_M, block_N), "float32")
T.annotate_layout({smem: tilelang.layout.make_swizzled_layout(smem)})

T.copy(
src[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
smem,
disable_tma=True # Local SIMT copy, so only the store below exercises TMA
)

# TMA store
T.copy(
smem,
dst[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
dst_pe=1 - T.get_rank())

return {
'simt_push_buffer': simt_push_buffer,
'simt_push_tile': simt_push_tile,
'simt_pull_tile': simt_pull_tile,
'tma_load_tile': tma_load_tile,
'tma_store_tile': tma_store_tile
}[kernel]


def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
M = args.M
N = args.N
BLOCK_M = 64
BLOCK_N = 128
threads = 128
assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other"

_, _, group = init_dist(local_rank, num_local_ranks)
allocator = tilelang.get_allocator(
size=2**25,
device="cuda",
is_distributed=True,
local_rank=local_rank,
num_local_ranks=num_local_ranks,
group=group)
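# 2**25 bytes = 32 MiB, comfortably enough for the two M x N float32 tensors below
# (4 MiB each at the default 1024 x 1024).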

kernel = get_kernel(M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel)
kernel.initialize(allocator=allocator)
if local_rank == 0:
print(kernel.get_kernel_source())

src = tilelang.tensor((M, N), torch.float32, allocator=allocator).normal_()
dst = tilelang.tensor((M, N), torch.float32, allocator=allocator)

torch.cuda.synchronize()
torch.distributed.barrier(group)
kernel(dst, src)
torch.cuda.synchronize()
torch.distributed.barrier(group)

dst_torchs = [torch.empty_like(src) for _ in range(num_local_ranks)]
dist.all_gather(dst_torchs, src, group)
dst_torch = dst_torchs[local_rank ^ 1]
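# With exactly two ranks, local_rank ^ 1 == 1 - local_rank, so dst_torch is the
# peer's src, which is what our dst should now contain.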

if torch.allclose(dst_torch, dst, atol=1e-6, rtol=1e-6):
print(f"rank {local_rank} check passed.✅")
else:
print(f"rank {local_rank} check failed.❌")
print(f"dst_torch: {dst_torch}, dst: {dst}")
raise ValueError("Test failed")

dist.destroy_process_group()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=1024, help='M dimension')
parser.add_argument('--N', type=int, default=1024, help='N dimension')
parser.add_argument('--kernel', type=str, default='simt_push_tile', help='kernel to use')
args = parser.parse_args()
num_processes = 2

torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes)
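To make the expected data movement explicit: with two ranks, every kernel variant ends with rank r's dst holding rank (1 - r)'s src, which is exactly what the all_gather check in main() verifies. A host-only sketch of that contract (plain PyTorch, illustrative sizes, no GPUs or process group required):

import torch

# Two "ranks", each with its own src and dst buffer.
srcs = [torch.randn(4, 4) for _ in range(2)]
dsts = [torch.empty(4, 4) for _ in range(2)]

# Push semantics: rank r writes its src into its peer's dst (peer = 1 - r).
for r in range(2):
    dsts[1 - r].copy_(srcs[r])

# Every kernel variant above should leave each rank's dst equal to its peer's src.
for r in range(2):
    assert torch.equal(dsts[r], srcs[1 - r])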
42 changes: 39 additions & 3 deletions src/op/copy.cc
@@ -155,9 +155,45 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() >= 5) {
node->eviction_policy = args[4].as<IntImmNode>()->value;
}

// Parse remote copy params
if (args.size() >= 6) {
node->src_pe = args[5];
}
if (args.size() >= 7) {
node->dst_pe = args[6];
}
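// A copy whose dst_pe names a peer rank is a remote push; one whose src_pe
// names a peer rank is a remote pull. A single copy cannot be both.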

ICHECK(!(node->is_remote_push() && node->is_remote_pull()))
<< "At least one of src_pe or dst_pe must be local rank";

if (node->is_remote_push()) {
ICHECK(node->dst.scope() == "global")
<< "Can only copy to peer's global memory, but got "
<< node->dst.scope();
} else if (node->is_remote_pull()) {
ICHECK(node->src.scope() == "global")
<< "Can only pull from peer's global memory, but got "
<< node->src.scope();
}

data_ = std::move(node);
}

bool CopyNode::is_remote_push() const {
return !(dst_pe->IsInstance<IntImmNode>() &&
dst_pe.as<IntImmNode>()->value == -1);
}

bool CopyNode::is_remote_pull() const {
return !(src_pe->IsInstance<IntImmNode>() &&
src_pe.as<IntImmNode>()->value == -1);
}

bool CopyNode::is_remote_copy() const {
return is_remote_push() || is_remote_pull();
}

/**
* @brief Create a shallow clone of this CopyNode as a TileOperator.
*
@@ -1940,11 +1976,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {

// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy, src_pe, dst_pe
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

9 changes: 9 additions & 0 deletions src/op/copy.h
@@ -92,6 +92,15 @@ class CopyNode : public TileOperatorNode {
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration

// Params for remote copy
PrimExpr src_pe; // Source PE for remote copy
PrimExpr dst_pe; // Destination PE for remote copy
Buffer symm_buffer; // Symmetric buffer for remote copy
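// An IntImm value of -1 in src_pe / dst_pe means "no remote peer"; the
// predicates below classify the copy accordingly (see copy.cc).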

bool is_remote_copy() const;
bool is_remote_push() const;
bool is_remote_pull() const;

mutable ParallelOp par_op_; // Optional associated parallelization operator

enum class EvictionPolicy : uint8_t {
5 changes: 5 additions & 0 deletions src/op/distributed.cc
@@ -207,6 +207,11 @@ TIR_DEFINE_TL_BUILTIN(get_remote_base_ptr)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_local_base)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(get_uintptr_t)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
5 changes: 5 additions & 0 deletions src/op/distributed.h
@@ -226,6 +226,11 @@ const Op &get_num_ranks();
*/
const Op &get_remote_base_ptr();

/*!
* \brief tvm intrinsics for getting the local base address (returned as uint64_t)
*/
const Op &get_local_base();

/*!
* \brief tvm intrinsics for getting the uintptr_t of a pointer
*/
5 changes: 5 additions & 0 deletions src/target/codegen_cuda.cc
@@ -297,6 +297,8 @@ std::string CodeGenTileLangCUDA::Finish() {

if (use_distributed_) {
decl_stream << "uint64_t __constant__ meta_data[1024];\n";
decl_stream
<< "uint64_t* host_meta_data = nullptr;\n"; // An alias of host_table
}
decl_stream << "#ifdef ENABLE_BF16\n";
decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
@@ -1543,6 +1545,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->use_distributed_ = true;
std::string pe_str = this->PrintExpr(op->args[0]);
os << "tl::get_remote_base_ptr(" << pe_str << ")";
} else if (op->op.same_as(tl::get_local_base())) {
this->use_distributed_ = true;
os << "tl::get_local_base()";
} else if (op->op.same_as(tl::get_uintptr_t())) {
os << "tl::get_uintptr_t(" << this->PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
4 changes: 4 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -32,6 +32,10 @@ using int4_t = int4;

#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TL_HOST __forceinline__ __host__
#define TL_HOST_NOINLINE __noinline__ __host__
#define TL_HOST_DEVICE __forceinline__ __host__ __device__
#define TL_HOST_DEVICE_NOINLINE __noinline__ __host__ __device__
#define TL_PATCH

#define TILELANG_CHECK(stmt) \
42 changes: 36 additions & 6 deletions src/tl_templates/cuda/distributed.h
@@ -1,20 +1,50 @@
#pragma once

#include "common.h"
#include <cstdint>

namespace tl {

extern "C" extern __device__ uint64_t meta_data[1024];
extern "C" __device__ uint64_t meta_data[1024];
extern "C" uint64_t *host_meta_data;

TL_HOST_DEVICE uint64_t get_rank() {
#ifdef __CUDA_ARCH__
return meta_data[0];
#else
return host_meta_data[0];
#endif
}

TL_HOST_DEVICE uint64_t get_num_ranks() {
#ifdef __CUDA_ARCH__
return meta_data[1];
#else
return host_meta_data[1];
#endif
}

TL_HOST_DEVICE void *get_remote_base_ptr(uint64_t rank) {
#ifdef __CUDA_ARCH__
return (void *)meta_data[2 + rank];
#else
return (void *)host_meta_data[2 + rank];
#endif
}

// NOTE(wt): Be careful about the return types here! get_local_base() returns a
// uint64_t because I could not find a way to cast a u64 to a pointer in TIR.
TL_HOST_DEVICE uint64_t get_local_base() {
#ifdef __CUDA_ARCH__
return meta_data[2 + get_rank()];
#else
return host_meta_data[2 + get_rank()];
#endif
}

template <typename dtype_t>
TL_HOST_DEVICE uint64_t get_uintptr_t(dtype_t *ptr) {
return reinterpret_cast<uint64_t>(ptr);
}
