Skip to content

Commit fef8d2a

Browse files
committed
lint fix
1 parent ce9e2b6 commit fef8d2a

File tree

13 files changed

+592
-563
lines changed

13 files changed

+592
-563
lines changed

.clang-tidy

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Checks: >
4646
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
4747
-clang-analyzer-deadcode.DeadStores,
4848
-clang-analyzer-optin.cplusplus.VirtualCall,
49+
-clang-diagnostic-tautological-constant-compare,
4950
5051
WarningsAsErrors: '*'
5152

src/layout/gemm_layouts.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
749749
element_size);
750750
}
751751
int vector_size = 128 / element_size;
752-
752+
753753
if (mat_continuous % (vector_size * 8) == 0)
754754
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
755755
else if (mat_continuous % (vector_size * 4) == 0)

src/layout/layout.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,10 +548,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
548548
return makeQuarterBankSwizzleLayout(stride, continuous,
549549
element_size);
550550
})
551-
.def("tl.make_linear_layout",
552-
[](int stride, int continuous) {
553-
return makeGemmLayoutLinear(stride, continuous);
554-
});
551+
.def("tl.make_linear_layout", [](int stride, int continuous) {
552+
return makeGemmLayoutLinear(stride, continuous);
553+
});
555554
});
556555

557556
TVM_FFI_STATIC_INIT_BLOCK({

src/layout/layout.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
168168
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
169169
int element_size, bool k_inner = true);
170170
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
171-
int continuity, int element_size, bool k_inner = true);
171+
int continuity, int element_size,
172+
bool k_inner = true);
172173
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
173174
int element_size, bool k_inner = true);
174175
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,

src/op/builtin.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,21 +220,21 @@ TVM_DLL const Op &mbarrier_expect_tx();
220220
* \brief tvm intrinsic for ptx tensor core wgmma instructions.
221221
*
222222
* void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool
223-
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
224-
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
225-
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
226-
* scale_in_a, bool scale_in_b);
223+
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
224+
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
225+
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
226+
* scale_out, bool scale_in_a, bool scale_in_b);
227227
*/
228228
TVM_DLL const Op &ptx_wgmma_ss();
229229

230230
/*!
231231
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
232232
*
233233
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
234-
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
235-
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
236-
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, bool
237-
* scale_in_a, bool scale_in_b);
234+
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
235+
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
236+
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
237+
* scale_out, bool scale_in_a, bool scale_in_b);
238238
*/
239239
TVM_DLL const Op &ptx_wgmma_rs();
240240

src/target/codegen_cuda.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,23 +1565,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
15651565
const bool a_is_shared = true;
15661566
this->PrintIndent();
15671567
std::string asm_code = PrintWGMMAAssembly(
1568-
shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc, A_offset,
1569-
b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
1570-
a_is_shared, "", "", "", false);
1568+
shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
1569+
A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
1570+
scale_in_b, a_is_shared, "", "", "", false);
15711571
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
1572-
std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
1572+
std::string wgmma_asm_code =
1573+
"tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
1574+
"(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
1575+
"uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
15731576
// replace patterns
15741577
tl::codegen::Replacer replacer;
1575-
replacer.register_rule("(AType)", tl::codegen::ptx::DTypeEnumToString(A_dtype));
1576-
replacer.register_rule("(BType)", tl::codegen::ptx::DTypeEnumToString(B_dtype));
1577-
replacer.register_rule("(CType)", tl::codegen::ptx::DTypeEnumToString(C_dtype));
1578+
replacer.register_rule("(AType)",
1579+
tl::codegen::ptx::DTypeEnumToString(A_dtype));
1580+
replacer.register_rule("(BType)",
1581+
tl::codegen::ptx::DTypeEnumToString(B_dtype));
1582+
replacer.register_rule("(CType)",
1583+
tl::codegen::ptx::DTypeEnumToString(C_dtype));
15781584
replacer.register_rule("(M)", std::to_string(m));
15791585
replacer.register_rule("(N)", std::to_string(n));
15801586
replacer.register_rule("(K)", std::to_string(k));
1581-
replacer.register_rule("(tnspA)", a_is_k_major? "false": "true");
1582-
replacer.register_rule("(tnspB)", b_is_k_major? "false": "true");
1583-
replacer.register_rule("(scaleA)", scale_in_a? "1": "-1");
1584-
replacer.register_rule("(scaleB)", scale_in_b? "1": "-1");
1587+
replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
1588+
replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
1589+
replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
1590+
replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
15851591
replacer.register_rule("(desc_a)", a_desc);
15861592
replacer.register_rule("(A_offset)", A_offset);
15871593
replacer.register_rule("(desc_b)", b_desc);

src/target/ptx.cc

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ namespace codegen {
3636
namespace ptx {
3737

3838
static const char *enum_to_str[] = {
39-
"kInt4", "kUInt4", "kInt8", "kUInt8", "kInt16", "kUInt16", "kInt32", "kUInt32",
40-
"kInt64", "kUInt64", "kFloat8_e4m3", "kFloat8_e5m2", "kFloat16", "kBFloat16", "kFloat16x2", "kFloat32",
41-
"kTensorFloat32", "kFloat64", "kBit1", "kBit8", "kBit16", "kBit32", "kBit64"};
39+
"kInt4", "kUInt4", "kInt8", "kUInt8", "kInt16",
40+
"kUInt16", "kInt32", "kUInt32", "kInt64", "kUInt64",
41+
"kFloat8_e4m3", "kFloat8_e5m2", "kFloat16", "kBFloat16", "kFloat16x2",
42+
"kFloat32", "kTensorFloat32", "kFloat64", "kBit1", "kBit8",
43+
"kBit16", "kBit32", "kBit64"};
4244

4345
static const char *dtype_str[] = {
4446
".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", ".u32",
@@ -103,13 +105,13 @@ DataType DTypeFromString(const std::string str) {
103105
}
104106
}
105107

106-
107108
std::string DTypeEnumToString(const ptx::DataType &dtype) {
108-
return "tl::DataType::" + std::string(enum_to_str[static_cast<int>(dtype)]);
109+
return "tl::DataType::" + std::string(enum_to_str[static_cast<int>(dtype)]);
109110
}
110111

111112
std::string DTypeEnumToString(const std::string &dtype) {
112-
return "tl::DataType::" + std::string(enum_to_str[static_cast<int>(DTypeFromString(dtype))]);
113+
return "tl::DataType::" +
114+
std::string(enum_to_str[static_cast<int>(DTypeFromString(dtype))]);
113115
}
114116

115117
/*!
@@ -1183,16 +1185,18 @@ PrintMMAAssembly(const std::string &shape, const std::string &A_layout,
11831185
return asm_code;
11841186
}
11851187

1186-
std::string PrintWGMMAAssembly(
1187-
const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
1188-
const std::string &A_dtype, const std::string &B_dtype,
1189-
const std::string &C_dtype, const std::string &a_desc,
1190-
const std::string &A_offset, const std::string &b_desc,
1191-
const std::string &B_offset, const std::string &c_ptr,
1192-
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
1193-
const bool &scale_in_b, const bool &a_is_shared,
1194-
const std::string &metadata, const std::string &metadata_offset,
1195-
const std::string &sparsity_selector, bool sparse) {
1188+
std::string
1189+
PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major,
1190+
const bool &b_is_k_major, const std::string &A_dtype,
1191+
const std::string &B_dtype, const std::string &C_dtype,
1192+
const std::string &a_desc, const std::string &A_offset,
1193+
const std::string &b_desc, const std::string &B_offset,
1194+
const std::string &c_ptr, const std::string &c_offset,
1195+
const bool &scale_out, const bool &scale_in_a,
1196+
const bool &scale_in_b, const bool &a_is_shared,
1197+
const std::string &metadata,
1198+
const std::string &metadata_offset,
1199+
const std::string &sparsity_selector, bool sparse) {
11961200
ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype),
11971201
dtype_b = ptx::DTypeFromString(B_dtype),
11981202
dtype_c = ptx::DTypeFromString(C_dtype);

src/target/ptx.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace tvm::tl {
3333
namespace codegen {
3434

3535
namespace ptx {
36-
36+
3737
/*!
3838
* \brief PTX data type.
3939
* \note
@@ -73,7 +73,6 @@ enum class DataType : int {
7373
*/
7474
DataType DTypeFromString(const std::string str);
7575

76-
7776
/*!
7877
* \brief Print ptx data type from enum.
7978
*/
@@ -90,7 +89,6 @@ std::string DTypeEnumToString(const std::string &dtype);
9089
std::tuple<int, int, int> ParseMMAShape(const std::string &str);
9190
} // namespace ptx
9291

93-
9492
/*!
9593
* \brief Replace patterns with replacement strings.
9694
* \note should use std::format instead when codebase is ported to C++20.
@@ -162,16 +160,18 @@ PrintMMAAssembly(const std::string &shape, const std::string &A_layout,
162160
* \param B_dtype The data type of multiplicand B.
163161
* \param C_dtype The data type of multiplicand C.
164162
*/
165-
std::string PrintWGMMAAssembly(
166-
const std::string &shape, const bool &a_is_k_major, const bool &b_is_k_major,
167-
const std::string &A_dtype, const std::string &B_dtype,
168-
const std::string &C_dtype, const std::string &a_desc,
169-
const std::string &A_offset, const std::string &b_desc,
170-
const std::string &B_offset, const std::string &c_ptr,
171-
const std::string &c_offset, const bool &scale_out, const bool &scale_in_a,
172-
const bool &scale_in_b, const bool &a_is_shared,
173-
const std::string &metadata, const std::string &metadata_offset,
174-
const std::string &sparsity_selector, bool sparse);
163+
std::string
164+
PrintWGMMAAssembly(const std::string &shape, const bool &a_is_k_major,
165+
const bool &b_is_k_major, const std::string &A_dtype,
166+
const std::string &B_dtype, const std::string &C_dtype,
167+
const std::string &a_desc, const std::string &A_offset,
168+
const std::string &b_desc, const std::string &B_offset,
169+
const std::string &c_ptr, const std::string &c_offset,
170+
const bool &scale_out, const bool &scale_in_a,
171+
const bool &scale_in_b, const bool &a_is_shared,
172+
const std::string &metadata,
173+
const std::string &metadata_offset,
174+
const std::string &sparsity_selector, bool sparse);
175175

176176
/*!
177177
* \brief Print ldmatrix assembly string given parameters.

src/tl_templates/cuda/common.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ using int4_t = int4;
5454
} \
5555
} while (0)
5656

57-
5857
// abs function for bfloat_t and half_t since there is no implicit conversion
5958
// method
6059
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {

0 commit comments

Comments
 (0)