55
66#include  " ../layout/layout.h" 
77#include  " ../layout/utils.h" 
8+ #include  " ../transform/loop_partition.h" 
89#include  " arith/int_operator.h" 
910#include  " arith/ir_visitor_with_analyzer.h" 
1011#include  " common/loop_vectorization_utils.h" 
@@ -29,6 +30,30 @@ struct AtomicAddVectorizePlanResult {
2930  PrimExpr condition;
3031};
3132
33+ class  BufferIndiceSimplify  : public  StmtExprMutator  {
34+ public: 
35+   BufferIndiceSimplify (arith::Analyzer *analyzer) : analyzer_(analyzer) {}
36+ 
37+ private: 
38+   PrimExpr VisitExpr_ (const  BufferLoadNode *node) final  {
39+     auto  visited = StmtExprMutator::VisitExpr_ (node);
40+     auto  n = Downcast<BufferLoad>(visited);
41+     auto  nptr = n.CopyOnWrite ();
42+     nptr->indices  = nptr->indices .Map (
43+         [&](const  auto  &e) { return  analyzer_->Simplify (e); });
44+     return  n;
45+   }
46+   Stmt VisitStmt_ (const  BufferStoreNode *node) final  {
47+     auto  visited = StmtExprMutator::VisitStmt_ (node);
48+     auto  n = Downcast<BufferStore>(visited);
49+     auto  nptr = n.CopyOnWrite ();
50+     nptr->indices  = nptr->indices .Map (
51+         [&](const  auto  &e) { return  analyzer_->Simplify (e); });
52+     return  n;
53+   }
54+   arith::Analyzer *analyzer_;
55+ };
56+ 
3257class  AtomicAddVectorizePlanner  : public  arith ::IRVisitorWithAnalyzer {
3358public: 
3459  AtomicAddVectorizePlanner () = default ;
@@ -137,69 +162,75 @@ class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
137162class  AtomicAddVectorizeRewriter  : public  StmtExprMutator  {
138163public: 
139164  AtomicAddVectorizeRewriter (const  AtomicAddVectorizePlanResult &plan,
140-                              Var thread_var, PrimExpr by_var, PrimExpr bx_var,
141-                              const  Range &thread_bounds, int  stride_y,
142-                              int  stride_x)
165+                              Var thread_var, const  Range &thread_bounds)
143166      : vector_size_(plan.vector_size), condition_(plan.condition),
144-         dynamic_ (plan.dynamic), tx_var_(std::move(thread_var)),
145-         by_var_(std::move(by_var)), bx_var_(std::move(bx_var)),
146-         stride_y_(stride_y), stride_x_(stride_x) {
167+         dynamic_ (plan.dynamic), tx_var_(std::move(thread_var)) {
147168    const  int64_t  *tx_ext = as_const_int (thread_bounds->extent );
148169    ICHECK (tx_ext)
149170        << " thread_bounds->extent must be a constant for vectorization." 
150171    extent_tx_ = static_cast <int >(*tx_ext);
151172  }
152173
153- private: 
154-   /* *
155-    * @brief Visits a For node and rewrites the innermost loop for atomic-add 
156-    * vectorization. 
157-    * 
158-    * If the visited For node is the recorded innermost loop, this method 
159-    * validates that the loop extent is a constant, divisible by the planned 
160-    * vector size, and has a zero minimum. When vectorization is enabled 
161-    * (dynamic_ == false) it: 
162-    *  - locates the thread index variable named "tx" inside the loop body, 
163-    *  - creates a new outer loop variable named "<old_loop_var>_outer", 
164-    *  - substitutes occurrences of `tx` with `tx * vector_size_` and the old 
165-    * loop var with `outer_var * vector_size_` so each outer iteration maps to a 
166-    * contiguous vector-sized chunk, 
167-    *  - returns a new For with extent divided by vector_size_ and the 
168-    * transformed body. 
169-    * 
170-    * If dynamic_ is true, the method returns the (possibly mutated) inner For 
171-    * unchanged. 
172-    * 
173-    * Side effects: 
174-    *  - updates inner_for_ to point to the current For node during visitation. 
175-    *  - performs runtime checks (ICHECK) to enforce: constant extent, extent % 
176-    * vector_size_ == 0, and zero loop minimum; violations terminate execution. 
177-    * 
178-    * @return The original or transformed For statement as a Stmt. 
179-    */  
180-   Stmt VisitStmt_ (const  ForNode *node) final  {
181-     inner_for_ = node;
182-     iter_var_ = Var (node->loop_var ->name_hint  + " _outer" 
183-     auto  ret = StmtExprMutator::VisitStmt_ (node);
184-     if  (inner_for_ == node) { //  rewrite the innermost loop
185-       For fnode = ret.as <For>().value ();
186-       auto  extent_ptr = as_const_int (fnode->extent );
187-       ICHECK (extent_ptr) << fnode->extent ;
188-       int  extent = *extent_ptr;
189-       ICHECK (extent % vector_size_ == 0 )
190-           << " extent: " "  vector_size_: " 
191-       ICHECK (is_zero (fnode->min ));
192-       if  (!dynamic_) {
193-         Map<Var, PrimExpr> vmap;
194-         vmap.Set (fnode->loop_var , iter_var_);
195-         Stmt body = Substitute (fnode->body , vmap);
196-         return  For (iter_var_, 0 , extent / vector_size_, fnode->kind , body,
197-                    fnode->thread_binding , fnode->annotations , fnode->span );
198-       }
174+   For run (For for_node, const  Fragment &loop_layout,
175+           arith::Analyzer *analyzer) {
176+     int  old_loop_depth = loop_layout->InputDim ();
177+     int  new_loop_depth = loop_layout->OutputDim ();
178+ 
179+     Array<Var> vars;
180+     for  (int  i = 0 ; i < new_loop_depth; i++) {
181+       Var var = Var (std::string{char (' i' 
182+       vars.push_back (var);
183+     }
184+     vars.push_back (tx_var_);
185+     Map<Var, PrimExpr> vmap;
186+     Stmt body = std::move (for_node);
187+     auto  inv_loop = loop_layout->Inverse ();
188+     auto  indices = inv_loop->Forward (Array<PrimExpr>(vars.begin (), vars.end ()));
189+     //  the innerest iter_var need expand because of vectorize
190+ 
191+     const  ForNode *loop = body.as <ForNode>();
192+     ICHECK (loop != nullptr );
193+     vmap.Set (loop->loop_var , indices[0 ] * vector_size_);
194+     body = loop->body ;
195+     for  (int  i = 1 ; i < old_loop_depth; i++) {
196+       const  ForNode *loop = body.as <ForNode>();
197+       ICHECK (loop != nullptr );
198+       vmap.Set (loop->loop_var , indices[i]);
199+       body = loop->body ;
199200    }
200-     return  ret;
201+     body = Substitute (body, vmap);
202+ 
203+     //  innerest iter_var extent need to be shorter because of vectorize
204+ 
205+     body = For (vars[new_loop_depth - 1 ],
206+                make_zero (vars[new_loop_depth - 1 ]->dtype ),
207+                div (inv_loop->InputShape ()[new_loop_depth - 1 ], vector_size_),
208+                ForKind::kSerial , body);
209+     analyzer->Bind (vars[new_loop_depth - 1 ],
210+                    Range (0 , div (inv_loop->InputShape ()[new_loop_depth - 1 ],
211+                                 vector_size_)));
212+ 
213+     for  (int  i = new_loop_depth - 2 ; i >= 0 ; i--) {
214+       body = For (vars[i], make_zero (vars[i]->dtype ),
215+                  div (inv_loop->InputShape ()[i], vector_size_), ForKind::kSerial ,
216+                  body);
217+       analyzer->Bind (vars[i], Range (0 , inv_loop->InputShape ()[i]));
218+     }
219+ 
220+     body = BufferIndiceSimplify (analyzer)(body);
221+ 
222+     auto  node = LoopPragmaUnroll (Downcast<For>(body));
223+     if  (loop_layout->ThreadRange ().defined ()) {
224+       auto  range = loop_layout->ThreadRange ();
225+       auto  thread_var_with_offset = tx_var_ - range->min ;
226+       node.CopyOnWrite ()->body  =
227+           Substitute (node->body , {{tx_var_, thread_var_with_offset}});
228+     }
229+     auto  new_stmt = this ->VisitStmt (node);
230+     return  Downcast<For>(new_stmt);
201231  }
202232
233+ private: 
203234  PrimExpr VisitExpr_ (const  CallNode *node) final  {
204235    if  (dynamic_) {
205236      return  StmtExprMutator::VisitExpr_ (node);
@@ -208,57 +239,18 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
208239      if  (node->op  == builtin::call_extern () && node->args .size () >= 2 ) {
209240        if  (const  auto  *func_name = node->args [0 ].as <StringImmNode>()) {
210241          if  (func_name->value  == " AtomicAdd" 
211-             //  Matrix[by * stride_y + i / (stride_x / (tx_txtent *
212-             //  vector_size_)) + tx_var_ / (stride_x / vector_size_),
213-             //         bx * stride_x + (i % (stride_x / (tx_extent *
214-             //         vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
215-             //         (stride / vector_size_)) * vector_size_]
216-             const  BufferLoadNode *old_dst_node =
242+             const  BufferLoadNode *temp_dst_node =
217243                node->args [1 ].as <BufferLoadNode>();
218-             const  BufferLoadNode *old_value_node  =
244+             const  BufferLoadNode *temp_value_node  =
219245                node->args [2 ].as <BufferLoadNode>();
220-             if  (!old_dst_node  || !old_value_node ) {
246+             if  (!temp_dst_node  || !temp_value_node ) {
221247              return  StmtExprMutator::VisitExpr_ (node);
222248            }
223-             Array<PrimExpr> dst_indices, value_indices;
224-             if  ((extent_tx_ * vector_size_) > stride_x_) {
225-               dst_indices.push_back (
226-                   by_var_ * stride_y_ +
227-                   iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
228-                   truncdiv (tx_var_, stride_x_ / vector_size_));
229-               dst_indices.push_back (
230-                   bx_var_ * stride_x_ +
231-                   truncmod (tx_var_, stride_x_ / vector_size_) * vector_size_);
232-               value_indices.push_back (
233-                   iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
234-                   truncdiv (tx_var_ * vector_size_, stride_x_));
235-               value_indices.push_back (
236-                   truncmod (tx_var_, stride_x_ / vector_size_) * vector_size_);
237-             } else  {
238-               dst_indices.push_back (
239-                   by_var_ * stride_y_ +
240-                   truncdiv (iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
241-                   truncdiv (tx_var_, stride_x_ / vector_size_));
242-               dst_indices.push_back (
243-                   bx_var_ * stride_x_ +
244-                   truncmod (iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
245-                       (extent_tx_ * vector_size_) +
246-                   truncmod (tx_var_, stride_x_ / vector_size_) * vector_size_);
247-               value_indices.push_back (
248-                   truncdiv (iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
249-                   truncdiv (tx_var_, stride_x_ / vector_size_));
250-               value_indices.push_back (
251-                   truncmod (iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
252-                       (extent_tx_ * vector_size_) +
253-                   truncmod (tx_var_, stride_x_ / vector_size_) * vector_size_);
254-             }
249+             const  BufferLoad dst_node =
250+                 Downcast<BufferLoad>(node->args [1 ].as <BufferLoadNode>());
251+             const  BufferLoad value_node =
252+                 Downcast<BufferLoad>(node->args [2 ].as <BufferLoadNode>());
255253
256-             BufferLoad dst_node =
257-                 BufferLoad (old_dst_node->buffer , dst_indices,
258-                            old_dst_node->predicate , old_dst_node->span );
259-             BufferLoad value_node =
260-                 BufferLoad (old_value_node->buffer , value_indices,
261-                            old_value_node->predicate , old_value_node->span );
262254            Call address_of_dst =
263255                Call (DataType::Handle (), builtin::address_of (), {dst_node});
264256            Call address_of_value =
@@ -287,10 +279,7 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
287279  const  int  vector_size_;
288280  const  PrimExpr condition_;
289281  const  bool  dynamic_;
290-   const  PrimExpr by_var_, bx_var_;
291-   int  stride_y_, stride_x_;
292282  const  Var tx_var_;
293-   Var iter_var_;
294283  int  extent_tx_;
295284};
296285
@@ -317,11 +306,10 @@ static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
317306}
318307
319308For VectorizeAtomicAdd (const  For &for_node, const  Var &thread_var,
320-                        const  Range &thread_bounds, int  compute_capability) {
309+                        const  Range &thread_bounds, int  compute_capability,
310+                        arith::Analyzer *analyzer, const  Fragment &loop_layout) {
321311
322312  int  vectorize_size_max = 1 ;
323-   int  stride_x = -1 , stride_y = -1 ;
324-   PrimExpr bx_var, by_var;
325313
326314  PostOrderVisit (for_node->body , [&](const  ObjectRef &obj) {
327315    if  (const  auto  *call = obj.as <CallNode>()) {
@@ -333,40 +321,22 @@ For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
333321        }
334322      }
335323    }
336-     if  (const  MulNode *mul = obj.as <MulNode>()) {
337-       const  VarNode *var = nullptr ;
338-       const  IntImmNode *imm = nullptr ;
339-       PrimExpr var_expr;
340-       if  ((var = mul->a .as <VarNode>()) && (imm = mul->b .as <IntImmNode>())) {
341-         var_expr = mul->a ;
342-       } else  if  ((var = mul->b .as <VarNode>()) &&
343-                  (imm = mul->a .as <IntImmNode>())) {
344-         var_expr = mul->b ;
345-       }
346-       if  (var && imm) {
347-         if  (var->name_hint  == " bx" 
348-           stride_x = imm->value ;
349-           bx_var = var_expr;
350-         } else  if  (var->name_hint  == " by" 
351-           stride_y = imm->value ;
352-           by_var = var_expr;
353-         }
354-       }
355-     }
356324  });
325+ 
357326  if  (vectorize_size_max != 1 ) {
358327    int  vectorize_hint = vectorize_size_max;
359328    AtomicAddVectorizePlanResult res = {1 , false , 0 };
360329    AtomicAddVectorizePlanner planner;
361-     res = planner.Plan (for_node, thread_var, thread_bounds, vectorize_hint);
330+     For simplified_for_node =
331+         PartitionLoop (for_node, thread_var, analyzer, loop_layout);
332+     res = planner.Plan (simplified_for_node, thread_var, thread_bounds,
333+                        vectorize_hint);
362334    vectorize_hint = res.vector_size ;
363335
364-     if  (vectorize_hint == 1  || stride_x == -1  || stride_y == -1  ||
365-         !bx_var.defined () || !by_var.defined ())
336+     if  (vectorize_hint == 1 )
366337      return  for_node;
367-     auto  rewriter = AtomicAddVectorizeRewriter (
368-         res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
369-     return  Downcast<For>(rewriter (for_node));
338+     auto  rewriter = AtomicAddVectorizeRewriter (res, thread_var, thread_bounds);
339+     return  rewriter.run (for_node, loop_layout, analyzer);
370340  } else  {
371341    return  for_node;
372342  }
0 commit comments