Skip to content

Commit 7d0bd4b

Browse files
committed
Enhance thread management and logging in TileLang compilation
- Added a method to check if printing is enabled during compilation, improving control over logging behavior. - Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output. - Added comments to clarify the purpose of changes and improve code readability.
1 parent d54d1aa commit 7d0bd4b

File tree

6 files changed

+38
-652
lines changed

6 files changed

+38
-652
lines changed

examples/deepseek_nsa/example_tilelang_nsa_decode.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
tilelang.testing.set_random_seed(42)
99

1010

11-
# TODO(@yu): checkout tma with nsa
11+
# TODO(lei): workaround, as threads is not divisible by warp group size,
12+
# auto warp specialization may have some bugs.
1213
@tilelang.jit(
1314
out_idx=[-1],
1415
pass_configs={
@@ -173,8 +174,6 @@ def main():
173174
block_counts=block_counts,
174175
block_size=block_size,
175176
)
176-
print("out", out)
177-
print("ref", ref)
178177
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
179178

180179

src/op/copy.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,8 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
881881
auto stride = as_const_int(shared_layout->InputShape()[0]);
882882
auto continuous = as_const_int(shared_layout->InputShape()[1]);
883883
ICHECK(stride != nullptr && continuous != nullptr);
884+
// We also need to check if the shape satisfies the following doc:
885+
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
884886
if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
885887
*stride, *continuous,
886888
shared_tensor->dtype.bits()))) {

src/transform/warp_specialized_rewriter.cc

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,17 +1405,36 @@ class WarpSpecializedRewriter : public StmtExprMutator {
14051405

14061406
class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
14071407
public:
1408+
// return true means this aws will be disabled
14081409
static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
14091410
WarpSpecializedDetector detector;
14101411
detector.VisitStmt(stmt);
1411-
return detector.has_warp_specialization_ ||
1412-
(detector.has_tma_op_ && detector.has_mbarrier_op_);
1412+
if (!detector.num_threads_is_divisible_by_warp_group_) {
1413+
LOG(WARNING)
1414+
<< "Auto warp specialization will be disabled because the number of "
1415+
"threads"
1416+
<< detector.thread_var_->dom->extent
1417+
<< "is not divisible by warp group size";
1418+
return true;
1419+
}
1420+
if (detector.has_warp_specialization_) {
1421+
LOG(WARNING) << "Auto warp specialization will be disabled because warp "
1422+
"specialization is manually enabled";
1423+
return true;
1424+
}
1425+
if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
1426+
LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
1427+
"and mbarrier are both present";
1428+
return true;
1429+
}
1430+
return false;
14131431
}
14141432

14151433
WarpSpecializedDetector() {
14161434
has_tma_op_ = false;
14171435
has_mbarrier_op_ = false;
14181436
has_warp_specialization_ = false;
1437+
num_threads_is_divisible_by_warp_group_ = false;
14191438
}
14201439

14211440
private:
@@ -1449,6 +1468,8 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
14491468
if (iv->thread_tag == "threadIdx.x") {
14501469
ICHECK(iv->dom->extent.as<IntImmNode>());
14511470
thread_var_ = iv;
1471+
num_threads_is_divisible_by_warp_group_ =
1472+
iv->dom->extent.as<IntImmNode>()->value % warp_group_size_ == 0;
14521473
}
14531474
}
14541475
IRVisitorWithAnalyzer::VisitStmt_(op);
@@ -1458,6 +1479,8 @@ class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
14581479
IterVar thread_var_;
14591480
bool has_mbarrier_op_{false};
14601481
bool has_warp_specialization_{false};
1482+
bool num_threads_is_divisible_by_warp_group_{false};
1483+
const int warp_group_size_ = 128;
14611484
};
14621485

14631486
using namespace tir::transform;

0 commit comments

Comments
 (0)