Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,7 @@ TVM_DLL const Op& ptx_cp_async();
* Var global_ptr,
* Expr global_offset,
* size_t bytes,
* Var barrier_ptr,
* Expr barrier_offset);
* int barrier_id);
*/
TVM_DLL const Op& ptx_cp_async_bulk();

Expand All @@ -681,43 +680,51 @@ TVM_DLL const Op& ptx_wait_group();
/*!
* \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
*
* ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
* ptx_cp_async_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_cp_async_barrier();

/*!
* \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
*
* ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
* ptx_init_barrier_thread_count(int barrier_id, int thread_count)
*
*/
TVM_DLL const Op& ptx_init_barrier_thread_count();

/*!
* \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
*
* ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
* ptx_arrive_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_arrive_barrier();

/*!
* \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
*
* ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
* ptx_arrive_barrier_expect_tx(int barrier_id, int byte_count)
*
*/
TVM_DLL const Op& ptx_arrive_barrier_expect_tx();

/*!
* \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
*
* ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
* ptx_wait_barrier(int barrier_id)
*
*/
TVM_DLL const Op& ptx_wait_barrier();

/*!
* \brief tvm intrinsics to create N barriers
*
* create_barriers(int barrier_count)
*
*/
TVM_DLL const Op& create_barriers();

/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
* For example, if each thread in a warp of size 32 has 4 elements from the result of
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,7 @@ def wrapped(*args, **kwargs):
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
Expand Down Expand Up @@ -2125,6 +2126,7 @@ def wrapped(*args, **kwargs):
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
"create_barriers",
"mma_store",
"mma_fill",
"vectorlow",
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
ptx_arrive_barrier,
ptx_arrive_barrier_expect_tx,
ptx_wait_barrier,
create_barriers,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
Expand Down
75 changes: 34 additions & 41 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by


def ptx_cp_async_bulk(
dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id
):
"""TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
Expand All @@ -1394,11 +1394,8 @@ def ptx_cp_async_bulk(
bytes : int
The data size to copy.

barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

Returns
-------
Expand All @@ -1413,8 +1410,7 @@ def ptx_cp_async_bulk(
global_ptr,
global_offset,
bytes,
barrier_ptr,
barrier_offset,
barrier_id,
)


Expand Down Expand Up @@ -1447,37 +1443,31 @@ def ptx_wait_group(num):
return call_intrin("", "tir.ptx_wait_group", num)


def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
def ptx_cp_async_barrier(barrier_id):
"""TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset)
return call_intrin("", "tir.ptx_cp_async_barrier", barrier_id)


def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
def ptx_init_barrier_thread_count(barrier_id, thread_count):
"""TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

thread_count : int
Number of threads expected to arrive at the barrier.
Expand All @@ -1487,43 +1477,35 @@ def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count):
call : PrimExpr
The call expression.
"""
return call_intrin(
"", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count
)
return call_intrin("", "tir.ptx_init_barrier_thread_count", barrier_id, thread_count)


def ptx_arrive_barrier(barrier_ptr, barrier_offset):
def ptx_arrive_barrier(barrier_id):
"""TVM intrinsic for ptx barrier arrival using mbarrier.arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset)
return call_intrin("", "tir.ptx_arrive_barrier", barrier_id)


def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
def ptx_arrive_barrier_expect_tx(barrier_id, byte_count):
"""TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

byte_count : int
Increases the tx count of the mbarrier object to track completion of
Expand All @@ -1534,29 +1516,40 @@ def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count):
call : PrimExpr
The call expression.
"""
return call_intrin(
"", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count
)
return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", barrier_id, byte_count)


def ptx_wait_barrier(barrier_ptr, barrier_offset):
def ptx_wait_barrier(barrier_id):
"""TVM intrinsic for ptx barrier wait using mbarrier.try_wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

Parameters
----------
barrier_ptr : Var
The barrier shared memory pointer variable.

barrier_id : int
The offset of the barrier shared memory pointer.
The index of the barrier within the shared memory barrier array.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_barrier", barrier_id)


def create_barriers(barrier_count):
"""TVM intrinsic to create N barriers

Parameters
----------
barrier_count : int
The number of barriers to create.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset)
return call_intrin("", "tir.create_barriers", barrier_count)


def vectorlow(dtype, vec):
Expand Down
61 changes: 40 additions & 21 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,42 +968,61 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
std::string barrier_ptr = this->PrintExpr(op->args[5]);
std::string barrier_offset = this->PrintExpr(op->args[6]);
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier_ptr,
barrier_offset);
int barrier_id = Downcast<IntImm>(op->args[5])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier);
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barrier_ptr = this->PrintExpr(op->args[0]);
std::string barrier_offset = this->PrintExpr(op->args[1]);
this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset);
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
std::string barrier_ptr = this->PrintExpr(op->args[0]);
std::string barrier_offset = this->PrintExpr(op->args[1]);
std::string thread_count = this->PrintExpr(op->args[2]);
this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr, barrier_offset, thread_count);
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string thread_count = this->PrintExpr(op->args[1]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barrier_ptr = this->PrintExpr(op->args[0]);
std::string barrier_offset = this->PrintExpr(op->args[1]);
this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset);
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
need_cast_smem_ptr_to_int_ = true;
std::string barrier_ptr = this->PrintExpr(op->args[0]);
std::string barrier_offset = this->PrintExpr(op->args[1]);
std::string byte_count = this->PrintExpr(op->args[2]);
this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset, byte_count);
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string byte_count = this->PrintExpr(op->args[1]);
this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
std::string barrier_ptr = this->PrintExpr(op->args[0]);
std::string barrier_offset = this->PrintExpr(op->args[1]);
this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset);
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::create_barriers())) {
CHECK_EQ(barrier_count_, -1);
int barrier_count = Downcast<IntImm>(op->args[0])->value;
// pad barrier alignment to avoid runtime alignment errors
CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
if (barrier_count % barrier_alignment_count != 0) {
barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count;
}
barrier_count_ = barrier_count;
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t "
<< barrier_name_ << "[" << barrier_count << "];\n";
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_
<< "[i] = 0; }\n";
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
Expand Down
8 changes: 8 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ class CodeGenCUDA final : public CodeGenC {
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

// The name of the barrier array in shared memory
const std::string barrier_name_ = "barrier";
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// The alignment of the barrier array in shared memory
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;

std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
Expand Down
Loading