@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
6464 ICHECK (thread_var.defined ());
6565 int old_loop_depth = loop_layout->InputDim ();
6666 int new_loop_depth = loop_layout->OutputDim ();
67-
6867 // Create the new loop iter var
6968 Array<Var> vars;
7069 for (int i = 0 ; i < new_loop_depth; i++) {
7170 Var var = Var (std::string{char (' i' + i)});
71+ analyzer->Bind (var, Range::FromMinExtent (make_zero (var->dtype ),
72+ loop_layout->OutputShape ()[i]));
7273 vars.push_back (var);
7374 }
7475 vars.push_back (thread_var);
7576 // create the substitute map, and the loop body
7677 Map<Var, PrimExpr> vmap;
7778 Stmt body = std::move (op);
78- auto inv_loop = loop_layout->Inverse ();
79+ Array<PrimExpr> loop_mins;
80+ Array<PrimExpr> loop_extents;
81+ auto inverse_info = loop_layout->InverseWithLevel ();
82+ auto inv_loop = inverse_info.first ;
83+ // A guard is required whenever the layout cannot be proven bijective
84+ bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
7985 auto indices = inv_loop->Forward (Array<PrimExpr>(vars.begin (), vars.end ()));
86+ // Normalize thread var once so we can reuse the same substitution later.
87+ Map<Var, PrimExpr> thread_offset_map;
88+ bool has_thread_offset = false ;
89+ if (loop_layout->ThreadRange ().defined ()) {
90+ auto range = loop_layout->ThreadRange ();
91+ thread_offset_map.Set (thread_var, thread_var - range->min );
92+ has_thread_offset = true ;
93+ }
8094 for (int i = 0 ; i < old_loop_depth; i++) {
8195 const ForNode *loop = body.as <ForNode>();
8296 ICHECK (loop != nullptr );
8397 vmap.Set (loop->loop_var , indices[i]);
98+ loop_mins.push_back (loop->min );
99+ loop_extents.push_back (loop->extent );
84100 body = loop->body ;
85101 }
86-
87102 // substitute and re-construct the serial loop
88103 body = Substitute (body, vmap);
104+ // Guard executes the recovered loop body only if each inverse-mapped iterator
105+ // falls back into the original For ranges. We first check every axis from the
106+ // old loop nest (old_loop_depth) and then the extra index produced by inverse
107+ // layouts that carry a replicate/thread component (`inv_output_shape`). Both
108+ // must stay within bounds to ensure correctness. Example: layout([i, j]) =
109+ // floor((i * 16 + j) / 32) may generate extra points when the new loop
110+ // enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
111+ // or replicate index fall outside their original extents.
112+ // Put differently, the guard protects non-surjective loop_layout mappings:
113+ // without it, the new loop nest would over-cover the parallel space and
114+ // run the body for points whose inverse-mapped i, j land outside the
115+ // original extents. The guard is only built when the inverse could not be
116+ // proven bijective (see need_guard above).
117+ PrimExpr guard = const_true ();
118+
119+ if (need_guard) {
120+ for (int i = 0 ; i < old_loop_depth; i++) {
121+ PrimExpr index = indices[i];
122+ if (has_thread_offset) {
123+ index = Substitute (index, thread_offset_map);
124+ }
125+ PrimExpr lower_bound = analyzer->Simplify (index >= loop_mins[i]);
126+ PrimExpr upper_bound =
127+ analyzer->Simplify (index < loop_mins[i] + loop_extents[i]);
128+ guard = And (guard, And (lower_bound, upper_bound));
129+ }
130+ auto inv_output_shape = inv_loop->OutputShape ();
131+ if (inv_output_shape.size () > static_cast <size_t >(old_loop_depth)) {
132+ PrimExpr replicate_index = indices[old_loop_depth];
133+ if (has_thread_offset) {
134+ replicate_index = Substitute (replicate_index, thread_offset_map);
135+ }
136+ PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
137+ PrimExpr lower_bound = analyzer->Simplify (
138+ replicate_index >= make_zero (replicate_index.dtype ()));
139+ PrimExpr upper_bound =
140+ analyzer->Simplify (replicate_index < replicate_extent);
141+ guard = And (guard, And (lower_bound, upper_bound));
142+ }
143+ PrimExpr simplified_guard = analyzer->Simplify (guard);
144+ if (!analyzer->CanProve (simplified_guard)) {
145+ body = IfThenElse (simplified_guard, body, Stmt ());
146+ }
147+ }
148+
89149 for (int i = new_loop_depth - 1 ; i >= 0 ; i--) {
90150 body = For (vars[i], make_zero (vars[i]->dtype ), inv_loop->InputShape ()[i],
91151 ForKind::kSerial , body);
@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
94154
95155 body = BufferIndiceSimplify (analyzer)(body);
96156
97- auto for_node = LoopPragmaUnroll (Downcast<For>(body));
98- if (loop_layout->ThreadRange ().defined ()) {
99- auto range = loop_layout->ThreadRange ();
100- auto thread_var_with_offset = thread_var - range->min ;
101- for_node.CopyOnWrite ()->body =
102- Substitute (for_node->body , {{thread_var, thread_var_with_offset}});
157+ if (has_thread_offset) {
158+ body = Substitute (body, thread_offset_map);
103159 }
160+
161+ auto for_node = LoopPragmaUnroll (Downcast<For>(body));
104162 return for_node;
105163}
106164
@@ -111,6 +169,10 @@ class LoopPramaUnroller : public StmtExprMutator {
111169private:
112170 Stmt VisitStmt_ (const ForNode *node) final {
113171 if (node->kind == ForKind::kSerial ) {
172+ auto analyzer = std::make_shared<arith::Analyzer>();
173+ if (as_const_int (analyzer->Simplify (node->extent )) == nullptr ) {
174+ return StmtExprMutator::VisitStmt_ (node);
175+ }
114176 For new_for = GetRef<For>(node);
115177 auto for_ptr = new_for.CopyOnWrite ();
116178 for_ptr->annotations .Set (tir::attr::pragma_unroll_explicit, Bool (false ));
@@ -127,22 +189,20 @@ class LoopPartitioner : public StmtExprVisitor {
127189
128190 Fragment Partition (const For &op, int num_thread, int vectorize_size) {
129191 this ->VisitStmt (op);
130- int loop_size_full = 1 ;
131- PrimExpr flattened = 0 ;
192+ ICHECK (!loop_vars_.empty ());
193+ DataType dtype = loop_vars_[0 ]->var .dtype ();
194+ PrimExpr flattened = make_const (dtype, 0 );
195+ PrimExpr vector_extent = make_const (dtype, vectorize_size);
196+ PrimExpr thread_extent_const = make_const (dtype, num_thread);
132197 for (size_t i = 0 ; i < loop_vars_.size (); i++) {
133- auto ext_ptr = as_const_int (loop_vars_[i]->dom ->extent );
134- ICHECK (ext_ptr)
135- << " Loop partitioner only works with constant loop sizes, but got "
136- << loop_vars_[i]->dom ->extent ;
137- int extent = *ext_ptr;
138- loop_size_full *= extent;
198+ PrimExpr extent = loop_vars_[i]->dom ->extent ;
139199 flattened = flattened * extent + loop_vars_[i]->var ;
140200 }
141- ICHECK (loop_size_full % vectorize_size == 0 );
142- PrimExpr access_idx = FloorDiv (flattened, vectorize_size );
143- PrimExpr thd = FloorMod (access_idx, num_thread);
144- PrimExpr idx = FloorDiv (access_idx, num_thread) * vectorize_size +
145- FloorMod (flattened, vectorize_size);
201+ PrimExpr access_idx = FloorDiv (flattened, vector_extent );
202+ PrimExpr thd = FloorMod (access_idx, thread_extent_const );
203+ PrimExpr idx = FloorDiv (access_idx, thread_extent_const) * vector_extent +
204+ FloorMod (flattened, vector_extent);
205+
146206 auto fragment = Fragment (loop_vars_, {idx}, {thd}, {});
147207 if (has_fragment_) {
148208 // for fragment buffer, we don't need to replicate the loop layout
0 commit comments