
Commit 0f11130

lint fix

Parent: ee01a93

2 files changed: 20 additions, 16 deletions

src/op/copy.cc (18 additions, 16 deletions)
@@ -364,7 +364,8 @@ bool Copy::CheckBulkLoad(Target target) const {
   if (!TargetHasBulkCopy(target))
     return false;
   // 2. src and dst must be global and shared
-  if (src.scope() != "global" || (dst.scope() != "shared.dyn" && dst.scope() != "shared"))
+  if (src.scope() != "global" ||
+      (dst.scope() != "shared.dyn" && dst.scope() != "shared"))
     return false;
   // 3. check shape.
   // TODO(lei): validate if we can utilize tma under this shape.
@@ -391,7 +392,8 @@ bool Copy::CheckBulkStore(Target target) const {
   if (!TargetHasBulkCopy(target))
     return false;
   // 2. src and dst must be shared.dyn and local.fragment
-  if ((src.scope() != "shared.dyn" && src.scope() != "shared") || dst.scope() != "global")
+  if ((src.scope() != "shared.dyn" && src.scope() != "shared") ||
+      dst.scope() != "global")
     return false;
   // 3. check shape.
   // TODO(lei): validate if we can utilize tma under this shape.
@@ -414,7 +416,8 @@ bool Copy::CheckBulkStore(Target target) const {
  * otherwise.
  */
 bool Copy::CheckLDSMCopy(Target target) const {
-  return TargetHasLdmatrix(target) && (src.scope() == "shared.dyn" || src.scope() == "shared") &&
+  return TargetHasLdmatrix(target) &&
+         (src.scope() == "shared.dyn" || src.scope() == "shared") &&
          dst.scope() == "local.fragment";
 }

@@ -883,10 +886,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
   ICHECK(stride != nullptr && continuous != nullptr);
   // We also need to check if the shape satisfies the following doc:
   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  if (StructuralEqual()(
-          shared_layout,
-          makeQuarterBankSwizzleLayout(*stride, *continuous,
-                                       shared_tensor->dtype.bits()))) {
+  if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout(
+                                           *stride, *continuous,
+                                           shared_tensor->dtype.bits()))) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
   } else if (StructuralEqual()(
                  shared_layout,
@@ -898,18 +900,18 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                  makeFullBankSwizzleLayout(*stride, *continuous,
                                            shared_tensor->dtype.bits()))) {
     desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
-  } else if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
-                                                  *stride, *continuous,
-                                                  shared_tensor->dtype.bits()))) {
-    LOG(WARNING) << "Bulk copy cannot support a padded layout for src: "
-                 << src->name << ", dst: " << dst->name
+  } else if (StructuralEqual()(
+                 shared_layout,
+                 makeGemmABLayoutPadded(*stride, *continuous,
+                                        shared_tensor->dtype.bits()))) {
+    LOG(WARNING) << "Bulk copy cannot support a padded layout for src: "
+                 << src->name << ", dst: " << dst->name
                  << ", fallback to normal copy";
     return LowerNormalCopy(T, analyzer);
   } else {
-    LOG(WARNING)
-        << "Came across unsupported swizzle layout for src: "
-        << src->name << ", dst: " << dst->name
-        << ", fallback to normal copy";
+    LOG(WARNING) << "Came across unsupported swizzle layout for src: "
+                 << src->name << ", dst: " << dst->name
+                 << ", fallback to normal copy";
     return LowerNormalCopy(T, analyzer);
   }
 }
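
Aside: the last two hunks only re-wrap the swizzle-selection chain in Copy::LowerBulkCopy to satisfy clang-format; behavior is unchanged. For readers skimming the diff, below is a standalone sketch of the pattern being re-indented, with Layout, StructuralEqual, and the make*SwizzleLayout helpers reduced to toy stand-ins (the real TVM/TileLang types take stride, continuous extent, and dtype bits; a 64B half-bank branch also appears to sit between the two hunks but is not shown in the diff context). The idea: match the shared-memory layout against each candidate swizzle layout in order, map the first structural hit to a CU_TENSOR_MAP_SWIZZLE_* mode, and fall back to a normal copy otherwise.

#include <iostream>
#include <optional>

// Mirrors CU_TENSOR_MAP_SWIZZLE_{32B,64B,128B} from the CUDA driver API.
enum class Swizzle { k32B, k64B, k128B };

// Toy stand-in for the shared-memory layout object compared in the diff.
struct Layout {
  int bytes_per_row;
};

// Toy structural comparison; the real code uses TVM's StructuralEqual().
bool StructuralEqual(const Layout &a, const Layout &b) {
  return a.bytes_per_row == b.bytes_per_row;
}

// Stand-ins for makeQuarter/Half/FullBankSwizzleLayout(stride, continuous, bits).
Layout makeQuarterBankSwizzleLayout() { return {32}; }
Layout makeHalfBankSwizzleLayout() { return {64}; }
Layout makeFullBankSwizzleLayout() { return {128}; }

// Returns the matching swizzle mode, or nullopt to signal the fallback
// path (the LOG(WARNING) + LowerNormalCopy branches in the real code).
std::optional<Swizzle> SelectSwizzle(const Layout &shared_layout) {
  if (StructuralEqual(shared_layout, makeQuarterBankSwizzleLayout()))
    return Swizzle::k32B;
  if (StructuralEqual(shared_layout, makeHalfBankSwizzleLayout()))
    return Swizzle::k64B;
  if (StructuralEqual(shared_layout, makeFullBankSwizzleLayout()))
    return Swizzle::k128B;
  return std::nullopt;  // padded or unrecognized layout: normal copy
}

int main() {
  Layout shared_layout{128};
  if (auto mode = SelectSwizzle(shared_layout))
    std::cout << "swizzle mode " << static_cast<int>(*mode) << "\n";
  else
    std::cout << "fallback to normal copy\n";
}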

tilelang/language/builtin.py (2 additions, 0 deletions)
@@ -159,11 +159,13 @@ def no_set_max_nreg():
     """
     return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))
 
+
 def disable_warp_group_reg_alloc():
     """Disable the warp group reg alloc.
     """
     return no_set_max_nreg()
 
+
 def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
     """Wait for memory barrier parity condition.
 
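Aside: the builtin.py change just adds the second blank line PEP 8 expects between top-level definitions. For context, both helpers touched here are thin wrappers that construct TIR intrinsic calls. A minimal, hypothetical usage sketch follows; it assumes tilelang is installed and that these helpers are re-exported under tilelang.language, which the module path suggests but the diff does not show.

# Hypothetical usage sketch (not from the commit); assumes the helpers in
# tilelang/language/builtin.py are importable via tilelang.language.
from tilelang import language as T

# disable_warp_group_reg_alloc() is an alias for no_set_max_nreg(): it
# builds the tl.no_set_max_nreg intrinsic, opting this warp group out of
# register reallocation.
nreg_call = T.disable_warp_group_reg_alloc()

# Builds a call that waits until the given mbarrier reaches parity phase 0;
# per the signature, mbarrier may be an int index, PrimExpr, or tir.Call.
wait_call = T.mbarrier_wait_parity(mbarrier=0, parity=0)

print(nreg_call, wait_call)  # both are TIR call expressions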