Commit 8c42d09

DerrickYLJ authored and spectrometerHBH committed
New op print (apache#44)
1 parent 9d0647b · commit 8c42d09

File tree: 7 files changed, +381 −10 lines changed

include/tvm/tir/builtin.h

Lines changed: 6 additions & 0 deletions
@@ -993,6 +993,12 @@ enum TVMStructFieldKind : int {
   kTVMFFIAnyUnionValue,
   kTVMValueKindBound_
 };
+
+/*!
+ * \brief Print the content of a buffer during runtime.
+ */
+TVM_DLL const Op& print_buffer();
+
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
@@ -2131,6 +2131,7 @@ def wrapped(*args, **kwargs):
 anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
 vscale = _op_wrapper(_tir_op.vscale)
 ignore_loop_partition = _op_wrapper(_tir_op.ignore_loop_partition)
+print_buffer = _op_wrapper(_tir_op.print_buffer)


 def _dtype_forward(func):
@@ -2478,6 +2479,7 @@ def wrapped(*args, **kwargs):
     "get_active_lane_mask",
     "call_kernel",
     "ignore_loop_partition",
+    "print_buffer",
 ]

 __all__ += [

python/tvm/tir/op.py

Lines changed: 28 additions & 0 deletions
@@ -3662,6 +3662,34 @@ def ignore_loop_partition(predicate) -> PrimExpr:
     return call_intrin("bool", "tir.ignore_loop_partition", predicate)


+def print_buffer(buffer_var, dtype, shape_size, *dims):
+    """Print the contents of a buffer at runtime on CUDA.
+
+    This intrinsic prints a buffer from within a running TVM kernel,
+    without requiring the generated CUDA code to be dumped and
+    instrumented by hand.
+
+    Parameters
+    ----------
+    buffer_var : Var
+        The data pointer of the buffer to be printed.
+
+    dtype : DataType
+        The data type of the buffer.
+
+    shape_size : int
+        The number of dimensions of the buffer.
+
+    *dims : int
+        The extents of the buffer's dimensions, in order.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return _ffi_api.print_buffer(buffer_var, dtype, shape_size, *dims)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
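For orientation, here is a minimal usage sketch. The script-side name T.print_buffer comes from the wrapper registered in ir.py above; the function name add_one and the 4x4 shape are illustrative, and the exact TVMScript spelling is an assumption rather than something this commit tests.

import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")):
    for i, j in T.grid(4, 4):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] + T.float32(1)
    # Arguments follow the op's convention: data pointer, dtype, ndim, then each extent.
    T.evaluate(T.print_buffer(B.data, "float32", 2, 4, 4))

When built for CUDA, the kernel prints B's metadata and contents from thread (0, 0, 0), as implemented in codegen_cuda.cc below.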

src/target/source/codegen_cuda.cc

Lines changed: 112 additions & 10 deletions
@@ -199,8 +199,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
-    decl_stream << "__device__ half max"
-                << "(half a, half b)\n"
+    decl_stream << "__device__ half max" << "(half a, half b)\n"
                 << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
@@ -214,8 +213,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_bf16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
     decl_stream << "#include <cuda_bf16.h>\n";
-    decl_stream << "__device__ nv_bfloat16 max"
-                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+    decl_stream << "__device__ nv_bfloat16 max" << "(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hgt(a, b) ? a : b;\n}\n";
     decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n return __hlt(a, b) ? a : b;\n}\n";
@@ -693,8 +691,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
   if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     if (t.lanes() == 2 || t.lanes() == 3) {
-      stream << vec << '.' << access[i % t.lanes()] << "="
-             << "(" << value << ");\n";
+      stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
     } else {
       std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
       stream << ac << "=";
@@ -1253,8 +1250,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
     this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
     // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
-    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
-           << ")\n";
+    stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n";
     stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
            << guard << ")\n";
     stream << ");\n";
@@ -1334,6 +1330,113 @@
       LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes;
     }
     EndScope(ssa_scope);
+  } else if (op->op.same_as(builtin::print_buffer())) {
+    ICHECK_GE(op->args.size(), 3U) << "Print operation expects at least three arguments";
+    const PrimExpr& arg = op->args[0];
+    const auto* var_node = arg.as<VarNode>();
+    DataType dtype = op->dtype;
+    int num_dims = op->args[2].as<IntImmNode>()->value;
+    Array<PrimExpr> shape;
+    for (size_t i = 3; i < op->args.size(); ++i) {
+      shape.push_back(op->args[i]);
+    }
+    std::string format_specifier;
+    bool is_float16 = false;
+    if (dtype.is_float()) {
+      if (dtype.bits() == 16) {
+        format_specifier = "%f";
+        is_float16 = true;
+      } else {
+        format_specifier = "%f";
+      }
+    } else if (dtype.is_int()) {
+      format_specifier = "%d";
+    } else if (dtype.is_uint()) {
+      format_specifier = "%u";
+    } else {
+      LOG(FATAL) << "Unsupported data type for print: " << dtype;
+    }
+    if (var_node) {
+      std::string buffer_name = GetVarID(var_node);
+      std::vector<std::string> indices;
+      for (int i = 0; i < num_dims; ++i) {
+        indices.push_back("i" + std::to_string(i));
+      }
+      auto nested_loops = [&](auto&& self, int dim) -> std::string {
+        if (dim >= num_dims) return "";
+        std::string loop_var = "i" + std::to_string(dim);
+        std::string body = self(self, dim + 1);
+        std::ostringstream oss;
+        oss << "for (int " << loop_var << " = 0; " << loop_var << " < "
+            << shape[dim].as<IntImmNode>()->value << "; ++" << loop_var << ") {\n";
+
+        if (dim == num_dims - 1) {
+          std::string index_calculation = indices[0];
+          for (size_t i = 1; i < indices.size(); ++i) {
+            index_calculation = indices[i] + " + " +
+                                std::to_string(shape[i - 1].as<IntImmNode>()->value) + " * (" +
+                                index_calculation + ")";
+          }
+          oss << " int idx = " << index_calculation << ";\n";
+          if (is_float16) {
+            oss << " if (" << loop_var << " == " << shape[dim].as<IntImmNode>()->value
+                << " - 1) {\n"
+                << " printf(\"" << format_specifier << "\", static_cast<float>(" << buffer_name
+                << "[idx]));\n"
+                << "} else {\n"
+                << " printf(\"" << format_specifier << " \", static_cast<float>(" << buffer_name
+                << "[idx]));\n"
+                << " }\n";
+          } else {
+            oss << " if (" << loop_var << " == " << shape[dim].as<IntImmNode>()->value
+                << " - 1) {\n"
+                << " printf(\"" << format_specifier << "\", " << buffer_name << "[idx]);\n"
+                << "} else {\n"
+                << " printf(\"" << format_specifier << " \", " << buffer_name << "[idx]);\n"
+                << " }\n";
+          }
+        } else {
+          oss << " printf(\"[\");\n"
+              << body << " if (" << loop_var << " == " << shape[dim].as<IntImmNode>()->value
+              << " - 1) {\n"
+              << " printf(\"]\");\n"
+              << "} else {\n"
+              << " printf(\"]\\n\");\n"
+              << " }\n";
+        }
+        oss << "}\n";
+        return oss.str();
+      };
+      os << "// print_buffer starts\n"
+         << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+         << " printf(\"\\nBuffer " << buffer_name << "\\nDatatype: " << dtype << "\\n"
+         << "Shape: [";
+      for (int i = 0; i < num_dims; ++i) {
+        os << "%d" << (i < num_dims - 1 ? ", " : "");
+      }
+      os << "]\\n\", ";
+      for (int i = 0; i < num_dims; ++i) {
+        os << shape[i].as<IntImmNode>()->value << (i < num_dims - 1 ? ", " : "");
+      }
+      os << ");\n"
+         << " printf(\"Buffer " << buffer_name << " Contents:\\n[\");\n";
+      std::string loops = nested_loops(nested_loops, 0);
+      os << loops << " printf(\"]\\n\");\n"
+         << "}\n"
+         << "// print_buffer ends\n";
+    } else {
+      std::string print_arg = PrintExpr(arg);
+      if (is_float16) {
+        os << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+           << " printf(\"" << format_specifier << "\\n\", static_cast<float>(" << print_arg
+           << "));\n"
+           << "}\n";
+      } else {
+        os << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+           << " printf(\"" << format_specifier << "\\n\", " << print_arg << ");\n"
+           << "}\n";
+      }
+    }
   } else if (op->op.same_as(builtin::thread_return())) {
     os << "return";
   } else {
@@ -1442,8 +1545,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   PrintVecConstructor(op->dtype, os);
   os << "(";
   for (int i = 0; i < lanes; i++) {
-    os << "(" << PrintExpr(op->base) << ")"
-       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+    os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")";
     if (i != lanes - 1) os << ", ";
   }
   os << ")";

src/tir/op/builtin.cc

Lines changed: 3 additions & 0 deletions
@@ -435,6 +435,9 @@ TIR_DEFINE_BUILTIN_FUNC(buffer_offset)
     .set_num_inputs(2)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

+TIR_DEFINE_BUILTIN_FUNC(print_buffer)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace builtin
 }  // namespace tir
 }  // namespace tvm

src/tir/op/op.cc

Lines changed: 18 additions & 0 deletions
@@ -1257,5 +1257,23 @@ PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {

   return p / q;
 }
+PrimExpr PrintOpPacked(Var data, DataType dtype, Array<PrimExpr> shape) {
+  Array<PrimExpr> args;
+  args.push_back(data);
+  args.push_back(tir::StringImm(runtime::DLDataType2String(dtype)));
+  args.push_back(make_const(DataType::UInt(32), shape.size()));
+  for (const auto& dim : shape) {
+    args.push_back(dim);
+  }
+  return tir::Call(dtype, tir::builtin::print_buffer(), args);
+}
+
+TVM_REGISTER_GLOBAL("tir.print_buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
+  Array<PrimExpr> shape;
+  for (int i = 3; i < args.size(); ++i) {
+    shape.push_back(args[i]);
+  }
+  *ret = PrintOpPacked(args[0], args[1].operator DataType(), shape);
+});

 }  // namespace tvm
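The Python entry point in op.py reaches this registration through _ffi_api. A quick, hedged illustration of building the call from Python (the handle Var is a stand-in for a real buffer's data pointer, and passing the dtype as a string assumes the usual FFI string-to-DataType conversion):

import tvm
from tvm import tir

data = tir.Var("A_data", "handle")  # stand-in for a buffer's data pointer
# (data, dtype, ndim, *extents) -> a tir.Call of builtin::print_buffer,
# constructed by PrintOpPacked via the "tir.print_buffer" global above.
call = tir.print_buffer(data, "float32", 2, 4, 4)
print(call)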
