@@ -36,9 +36,11 @@ namespace codegen {
3636namespace ptx {
3737
3838static const char *enum_to_str[] = {
39- " kInt4" , " kUInt4" , " kInt8" , " kUInt8" , " kInt16" , " kUInt16" , " kInt32" , " kUInt32" ,
40- " kInt64" , " kUInt64" , " kFloat8_e4m3" , " kFloat8_e5m2" , " kFloat16" , " kBFloat16" , " kFloat16x2" , " kFloat32" ,
41- " kTensorFloat32" , " kFloat64" , " kBit1" , " kBit8" , " kBit16" , " kBit32" , " kBit64" };
39+ " kInt4" , " kUInt4" , " kInt8" , " kUInt8" , " kInt16" ,
40+ " kUInt16" , " kInt32" , " kUInt32" , " kInt64" , " kUInt64" ,
41+ " kFloat8_e4m3" , " kFloat8_e5m2" , " kFloat16" , " kBFloat16" , " kFloat16x2" ,
42+ " kFloat32" , " kTensorFloat32" , " kFloat64" , " kBit1" , " kBit8" ,
43+ " kBit16" , " kBit32" , " kBit64" };
4244
4345static const char *dtype_str[] = {
4446 " .s4" , " .u4" , " .s8" , " .u8" , " .s16" , " .u16" , " .s32" , " .u32" ,
@@ -103,13 +105,13 @@ DataType DTypeFromString(const std::string str) {
103105 }
104106}
105107
106-
107108std::string DTypeEnumToString (const ptx::DataType &dtype) {
108- return " tl::DataType::" + std::string (enum_to_str[static_cast <int >(dtype)]);
109+ return " tl::DataType::" + std::string (enum_to_str[static_cast <int >(dtype)]);
109110}
110111
111112std::string DTypeEnumToString (const std::string &dtype) {
112- return " tl::DataType::" + std::string (enum_to_str[static_cast <int >(DTypeFromString (dtype))]);
113+ return " tl::DataType::" +
114+ std::string (enum_to_str[static_cast <int >(DTypeFromString (dtype))]);
113115}
114116
115117/* !
@@ -1183,16 +1185,18 @@ PrintMMAAssembly(const std::string &shape, const std::string &A_layout,
11831185 return asm_code;
11841186}
11851187
1186- std::string PrintWGMMAAssembly (
1187- const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
1188- const std::string &A_dtype, const std::string &B_dtype,
1189- const std::string &C_dtype, const std::string &a_desc,
1190- const std::string &A_offset, const std::string &b_desc,
1191- const std::string &B_offset, const std::string &c_ptr,
1192- const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
1193- const bool &scale_in_b, const bool &a_is_shared,
1194- const std::string &metadata, const std::string &metadata_offset,
1195- const std::string &sparsity_selector, bool sparse) {
1188+ std::string
1189+ PrintWGMMAAssembly (const std::string &shape, const bool &a_is_k_major,
1190+ const bool &b_is_k_major, const std::string &A_dtype,
1191+ const std::string &B_dtype, const std::string &C_dtype,
1192+ const std::string &a_desc, const std::string &A_offset,
1193+ const std::string &b_desc, const std::string &B_offset,
1194+ const std::string &c_ptr, const std::string &c_offset,
1195+ const bool &scale_out, const bool &scale_in_a,
1196+ const bool &scale_in_b, const bool &a_is_shared,
1197+ const std::string &metadata,
1198+ const std::string &metadata_offset,
1199+ const std::string &sparsity_selector, bool sparse) {
11961200 ptx::DataType dtype_a = ptx::DTypeFromString (A_dtype),
11971201 dtype_b = ptx::DTypeFromString (B_dtype),
11981202 dtype_c = ptx::DTypeFromString (C_dtype);
0 commit comments