Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,5 +301,15 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(device_assert)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
15 changes: 15 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ TVM_DLL const Op &initialize_descriptor();
*/

TVM_DLL const Op &increase_descriptor_offset();

/*!
* \brief tilelang intrinsic for element-wise atomic addition.
*
Expand All @@ -513,6 +514,20 @@ TVM_DLL const Op &increase_descriptor_offset();
*/
TVM_DLL const Op &atomicadd_elem_op();

/*!
* \brief tilelang intrinsic for assert on device.
*
* This op is used to represent an assert on device
*/
TVM_DLL const Op &device_assert();

/*!
* \brief tilelang intrinsic for assert on device with additional message.
*
* This op is used to represent an assert on device with additional message.
*/
TVM_DLL const Op &device_assert_with_msg();

} // namespace tl
} // namespace tvm

Expand Down
10 changes: 10 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = PrintExpr(call->args[0]);
this->PrintIndent();
stream << "device_assert(" << cond << ");\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = PrintExpr(call->args[0]);
std::string msg_expr = PrintExpr(call->args[1]);
this->PrintIndent();
stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
} else {
CodeGenC::VisitStmt_(op);
}
Expand Down
5 changes: 3 additions & 2 deletions tilelang/language/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from tvm import tir
from typing import Any
import tilelang.language as T
from tilelang.language.kernel import get_thread_bindings
from tilelang.language import copy, macro, serial, alloc_shared
from tilelang.language.utils import index_to_coordinates
Expand Down Expand Up @@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
"""
if _IS_CUDA_AVAILABLE:
if msg == "":
tir.call_extern("void", "device_assert", condition)
T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
tir.call_extern("void", "device_assert_with_msg", condition, msg)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)


def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
Expand Down
Loading