Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ using namespace tir;
using namespace runtime;
using arith::IRVisitorWithAnalyzer;

// Per-loop record stored in WSCodeEmitter::loop_stack_ for every enclosing
// loop. It keeps the loop's lower bound alongside its variable and extent so
// that the linearized iteration index can be built from the zero-based offset
// (loop_var - min) rather than from the raw loop variable.
//
// The struct is aggregate-initialized positionally as
// LoopInfo{loop_var, extent, min}, so the field order below must not change.
struct LoopInfo {
// Iteration variable of the loop (taken from op->loop_var).
Var loop_var;
// Extent of the loop (taken from op->extent); used as the multiplier when
// folding this loop into the linearized index of the loop nest.
PrimExpr extent;
// Lower bound of the loop (taken from op->min); subtracted from loop_var
// before the variable contributes to the linearized index.
PrimExpr min;
};

// Role tag with three values: consumer, producer, or both.
// NOTE(review): presumably classifies a statement/buffer access as belonging
// to the producer side, the consumer side, or both when the kernel is split
// for warp specialization — confirm against FilterByRole and its callers.
enum class Role { kConsumer, kProducer, kBoth };

class ProducerBufferDetector : public StmtExprVisitor {
Expand Down Expand Up @@ -838,7 +844,7 @@ class WSCodeEmitter : public StmtMutator {
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
}
loop_stack_.emplace_back(op->loop_var, op->extent);
loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});

Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array;
Expand Down Expand Up @@ -871,15 +877,14 @@ class WSCodeEmitter : public StmtMutator {

num_stages_ = num_stages;
pipeline_info_ = pipeline_info;
PrimExpr linear_index = loop_stack_[0].first;
PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
for (size_t i = 1; i < loop_stack_.size(); ++i) {
linear_index =
linear_index * loop_stack_[i].second + loop_stack_[i].first;
linear_index = linear_index * loop_stack_[i].extent +
(loop_stack_[i].loop_var - loop_stack_[i].min);
}
stage_ = FloorMod(linear_index, num_stages);
parity_ = FloorMod(
parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);

auto result = FilterByRole(op);

Stmt grouped_for_node;
Expand Down Expand Up @@ -1137,7 +1142,7 @@ class WSCodeEmitter : public StmtMutator {
PrimExpr parity_ = 0;
PrimExpr stage_ = 0;
int num_stages_ = 1;
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
std::vector<LoopInfo> loop_stack_;
Var thread_var_;
bool mbarrier_only_ = false;
PipelineInfo pipeline_info_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,10 +550,10 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):


def test_all():
run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_static(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(1024, 1024, 1024, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge)(
mod)
print("mod \n", mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass
Expand Down
Loading