@@ -199,8 +199,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
-    decl_stream << "__device__ half max"
-                << "(half a, half b)\n"
+    decl_stream << "__device__ half max" << "(half a, half b)\n"
                 << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
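For context, this preamble is prepended to generated kernels when fp16 math is enabled; the bf16 hunk below does the same for `nv_bfloat16` behind an sm_80 guard. Reconstructed from the string literals above, the emitted CUDA looks roughly like this (a sketch, not part of the commit; the closing `#endif` is emitted outside this hunk):

```cuda
// Sketch of the preamble emitted when enable_fp16_ is set (sm_53+ only).
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>
// half has no built-in max/min overloads, so compare with __hgt/__hlt and select.
__device__ half max(half a, half b) { return __hgt(__half(a), __half(b)) ? a : b; }
__device__ half min(half a, half b) { return __hlt(__half(a), __half(b)) ? a : b; }
#endif  // assumed closing guard, emitted elsewhere in Finish()
```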
@@ -214,8 +213,7 @@ std::string CodeGenCUDA::Finish() {
   if (enable_bf16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
     decl_stream << "#include <cuda_bf16.h>\n";
-    decl_stream << "__device__ nv_bfloat16 max"
-                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+    decl_stream << "__device__ nv_bfloat16 max" << "(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n  return __hgt(a, b) ? a : b;\n}\n";
     decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
                 << "{\n  return __hlt(a, b) ? a : b;\n}\n";
@@ -693,8 +691,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
   if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     if (t.lanes() == 2 || t.lanes() == 3) {
-      stream << vec << '.' << access[i % t.lanes()] << "="
-             << "(" << value << ");\n";
+      stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
     } else {
       std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
       stream << ac << "=";
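The `lanes() == 2 || lanes() == 3` path above stores one 8-bit lane at a time through the vector's named accessors. For a `char3` vector `v`, lane 1, and value `x` (hypothetical names), the generated statement is just:

```cuda
// Per-lane store emitted by the lanes == 2/3 branch; no bit packing needed.
char3 v;
signed char x = 5;
v.y = (x);  // access[i % t.lanes()] selected the .y accessor for i == 1
```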
@@ -1253,8 +1250,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
       this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
       // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n";
-      stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
-             << ")\n";
+      stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n";
       stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
              << guard << ")\n";
       stream << ");\n";
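This hunk only touches the output-constraint line of a longer inline-PTX block; the predicate `p` is declared and set by neighboring lines not shown in the diff. Under that assumption, and with `reg`, `local_addr`, `global_buffer`, `global_addr`, and `guard` standing in for the codegen-time values, the assembled asm is roughly:

```cuda
// Sketch of the guarded non-coherent load the strings above assemble.
// The predicate setup (.reg .pred / setp) is assumed from surrounding code.
__device__ void guarded_load_f32(float* reg, int local_addr, const float* global_buffer,
                                 int global_addr, int guard) {
  asm volatile(
      "{.reg .pred p;\n"
      " setp.ne.b32 p, %2, 0;\n"
      " @!p mov.b32 %0, 0;\n"
      " @p ld.global.nc.f32 %0, [%1];}\n"
      : "=f"(reg[local_addr])
      : "l"((void*)(global_buffer + global_addr)), "r"((int)guard));
}
```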
@@ -1334,6 +1330,113 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
       LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes;
     }
     EndScope(ssa_scope);
+  } else if (op->op.same_as(builtin::print_buffer())) {
+    ICHECK_GE(op->args.size(), 3U) << "Print operation expects at least three arguments";
+    const PrimExpr& arg = op->args[0];
+    const auto* var_node = arg.as<VarNode>();
+    DataType dtype = op->dtype;
+    int num_dims = op->args[2].as<IntImmNode>()->value;
+    Array<PrimExpr> shape;
+    for (size_t i = 3; i < op->args.size(); ++i) {
+      shape.push_back(op->args[i]);
+    }
+    std::string format_specifier;
+    bool is_float16 = false;
+    if (dtype.is_float()) {
+      format_specifier = "%f";
+      // printf has no half conversion; 16-bit floats are cast to float before printing.
+      is_float16 = (dtype.bits() == 16);
+    } else if (dtype.is_int()) {
+      format_specifier = "%d";
+    } else if (dtype.is_uint()) {
+      format_specifier = "%u";
+    } else {
+      LOG(FATAL) << "Unsupported data type for print: " << dtype;
+    }
+    if (var_node) {
+      std::string buffer_name = GetVarID(var_node);
+      std::vector<std::string> indices;
+      for (int i = 0; i < num_dims; ++i) {
+        indices.push_back("i" + std::to_string(i));
+      }
+      auto nested_loops = [&](auto&& self, int dim) -> std::string {
+        if (dim >= num_dims) return "";
+        std::string loop_var = "i" + std::to_string(dim);
+        std::string body = self(self, dim + 1);
+        std::ostringstream oss;
+        oss << "for (int " << loop_var << " = 0; " << loop_var << " < "
+            << shape[dim].as<IntImmNode>()->value << "; ++" << loop_var << ") {\n";
+
+        if (dim == num_dims - 1) {
+          // Row-major linearization in Horner form:
+          //   idx = i_{n-1} + d_{n-1} * (i_{n-2} + d_{n-2} * (... (i_0)))
+          std::string index_calculation = indices[0];
+          for (size_t i = 1; i < indices.size(); ++i) {
+            index_calculation = indices[i] + " + " +
+                                std::to_string(shape[i].as<IntImmNode>()->value) + " * (" +
+                                index_calculation + ")";
+          }
+          oss << "int idx = " << index_calculation << ";\n";
+          std::string element = is_float16
+                                    ? "static_cast<float>(" + buffer_name + "[idx])"
+                                    : buffer_name + "[idx]";
+          // Separate elements with ", ", omitting the separator after the last one.
+          oss << "if (" << loop_var << " == " << shape[dim].as<IntImmNode>()->value
+              << " - 1) {\n"
+              << "  printf(\"" << format_specifier << "\", " << element << ");\n"
+              << "} else {\n"
+              << "  printf(\"" << format_specifier << ", \", " << element << ");\n"
+              << "}\n";
+        } else {
+          oss << "printf(\"[\");\n"
+              << body << "if (" << loop_var << " == " << shape[dim].as<IntImmNode>()->value
+              << " - 1) {\n"
+              << "  printf(\"]\");\n"
+              << "} else {\n"
+              << "  printf(\"]\\n\");\n"
+              << "}\n";
+        }
+        oss << "}\n";
+        return oss.str();
+      };
+      os << "// print_buffer starts\n"
+         << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+         << "printf(\"\\nBuffer " << buffer_name << "\\nDatatype: " << dtype << "\\n"
+         << "Shape: [";
+      for (int i = 0; i < num_dims; ++i) {
+        os << "%d" << (i < num_dims - 1 ? ", " : "");
+      }
+      os << "]\\n\", ";
+      for (int i = 0; i < num_dims; ++i) {
+        os << shape[i].as<IntImmNode>()->value << (i < num_dims - 1 ? ", " : "");
+      }
+      os << ");\n"
+         << "printf(\"Buffer " << buffer_name << " Contents:\\n[\");\n";
+      std::string loops = nested_loops(nested_loops, 0);
+      os << loops << "printf(\"]\\n\");\n"
+         << "}\n"
+         << "// print_buffer ends\n";
+    } else {
+      std::string print_arg = PrintExpr(arg);
+      if (is_float16) {
+        os << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+           << "printf(\"" << format_specifier << "\\n\", static_cast<float>(" << print_arg
+           << "));\n"
+           << "}\n";
+      } else {
+        os << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n"
+           << "printf(\"" << format_specifier << "\\n\", " << print_arg << ");\n"
+           << "}\n";
+      }
+    }
   } else if (op->op.same_as(builtin::thread_return())) {
     os << "return";
   } else {
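To make the handler concrete: for a hypothetical 2x3 float16 buffer named `A_shared`, the emitted device code would look roughly like the following sketch. The buffer name, shape, and surrounding kernel are made up for illustration; the single-thread guard and bracketed layout come from the emission logic above.

```cuda
#include <cstdio>
#include <cuda_fp16.h>

// Sketch of the code the print_buffer handler emits for a hypothetical
// 2x3 float16 buffer named A_shared.
__global__ void demo(half* A_shared) {
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
    printf("\nBuffer A_shared\nDatatype: float16\nShape: [%d, %d]\n", 2, 3);
    printf("Buffer A_shared Contents:\n[");
    for (int i0 = 0; i0 < 2; ++i0) {
      printf("[");
      for (int i1 = 0; i1 < 3; ++i1) {
        int idx = i1 + 3 * (i0);  // row-major linearization
        if (i1 == 3 - 1) {
          printf("%f", static_cast<float>(A_shared[idx]));
        } else {
          printf("%f, ", static_cast<float>(A_shared[idx]));
        }
      }
      if (i0 == 2 - 1) {
        printf("]");
      } else {
        printf("]\n");
      }
    }
    printf("]\n");
  }
}
```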
@@ -1442,8 +1545,7 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   PrintVecConstructor(op->dtype, os);
   os << "(";
   for (int i = 0; i < lanes; i++) {
-    os << "(" << PrintExpr(op->base) << ")"
-       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
+    os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")";
     if (i != lanes - 1) os << ", ";
   }
   os << ")";
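After the join, the ramp lowering is unchanged: each lane is printed as `(base)+(stride*i)` inside a vector constructor. For a hypothetical 4-lane float ramp the printed expression is equivalent to:

```cuda
// Sketch of the vector a float32x4 Ramp(base, stride) lowers to
// (base and stride are hypothetical stand-ins for the printed exprs).
float base = 0.f, stride = 2.f;
float4 v = make_float4((base)+(stride*0), (base)+(stride*1),
                       (base)+(stride*2), (base)+(stride*3));
```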