Skip to content

Commit 12d0094

Browse files
committed
Refactor loop stack management in warp_specialized_rewriter
- Introduced a new `LoopInfo` struct to encapsulate loop variable details (`loop_var`, `extent`, and `min`), enhancing clarity and maintainability.
- Updated `loop_stack_` to hold `LoopInfo` entries instead of a `std::pair<Var, PrimExpr>`, improving type safety and readability.
- Adjusted the linear index calculation to use the new structure: each loop variable is now offset by its loop's `min` (`loop_var - min`), so the computed index is correct for loops that do not start at zero.
1 parent 6a0444e commit 12d0094

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

src/transform/warp_specialized_rewriter.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ using namespace tir;
2424
using namespace runtime;
2525
using arith::IRVisitorWithAnalyzer;
2626

27+
struct LoopInfo {
28+
Var loop_var;
29+
PrimExpr extent;
30+
PrimExpr min;
31+
};
32+
2733
enum class Role { kConsumer, kProducer, kBoth };
2834

2935
class ProducerBufferDetector : public StmtExprVisitor {
@@ -838,7 +844,7 @@ class WSCodeEmitter : public StmtMutator {
838844
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
839845
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
840846
}
841-
loop_stack_.emplace_back(op->loop_var, op->extent);
847+
loop_stack_.emplace_back(LoopInfo{op->loop_var, op->extent, op->min});
842848

843849
Array<Array<Integer>> group_info_array;
844850
Array<Integer> order_info_array;
@@ -871,15 +877,14 @@ class WSCodeEmitter : public StmtMutator {
871877

872878
num_stages_ = num_stages;
873879
pipeline_info_ = pipeline_info;
874-
PrimExpr linear_index = loop_stack_[0].first;
880+
PrimExpr linear_index = loop_stack_[0].loop_var - loop_stack_[0].min;
875881
for (size_t i = 1; i < loop_stack_.size(); ++i) {
876-
linear_index =
877-
linear_index * loop_stack_[i].second + loop_stack_[i].first;
882+
linear_index = linear_index * loop_stack_[i].extent +
883+
(loop_stack_[i].loop_var - loop_stack_[i].min);
878884
}
879885
stage_ = FloorMod(linear_index, num_stages);
880886
parity_ = FloorMod(
881887
parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);
882-
883888
auto result = FilterByRole(op);
884889

885890
Stmt grouped_for_node;
@@ -1137,7 +1142,7 @@ class WSCodeEmitter : public StmtMutator {
11371142
PrimExpr parity_ = 0;
11381143
PrimExpr stage_ = 0;
11391144
int num_stages_ = 1;
1140-
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
1145+
std::vector<LoopInfo> loop_stack_;
11411146
Var thread_var_;
11421147
bool mbarrier_only_ = false;
11431148
PipelineInfo pipeline_info_;

0 commit comments

Comments
 (0)