2424
2525#include  " loop_vectorize.h" 
2626
27- #include  < tvm/arith/iter_affine_map.h> 
28- #include  < tvm/tir/builtin.h> 
29- #include  < tvm/tir/stmt_functor.h> 
30- 
31- #include  < numeric> 
32- 
33- #include  " ../layout/layout.h" 
34- #include  " ../layout/utils.h" 
3527#include  " arith/int_operator.h" 
3628#include  " arith/ir_visitor_with_analyzer.h" 
3729#include  " common/loop_vectorization_utils.h" 
30+ #include  " tvm/tir/analysis.h" 
31+ #include  " tvm/tir/var.h" 
32+ #include  < tvm/arith/iter_affine_map.h> 
33+ #include  < tvm/tir/builtin.h> 
34+ #include  < tvm/tir/stmt_functor.h> 
3835
3936namespace  tvm  {
4037namespace  tl  {
@@ -56,15 +53,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
5653    return  vector_size_;
5754  }
5855
59-   bool  GetDynamic () { return  dynamic_; }
60- 
61-   PrimExpr GetCondition () { return  condition_; }
62- 
6356private: 
6457  void  VisitStmt_ (const  ForNode *node) final  {
6558    inner_for_ = node;
66-     iter_map_.Set (node->loop_var , Range (node->min , node->extent ));
67- 
59+     auto  extent_ptr = as_const_int (node->extent );
60+     //  Here I disable dynamic shape completely,
61+     //    In order to do it, the Planner should accept an analyzer with
62+     //    arithmetic info outside to prove the dividiblity of vector size
63+     if  (!extent_ptr) {
64+       vector_size_ = 1 ;
65+       return ;
66+     }
67+     vector_size_ = arith::ZeroAwareGCD (vector_size_, *extent_ptr);
6868    arith::IRVisitorWithAnalyzer::VisitStmt_ (node);
6969  }
7070
@@ -113,76 +113,47 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
113113  void  UpdateVectorSize (const  Array<PrimExpr> &indices, const  Buffer &buffer) {
114114    if  (!inner_for_)
115115      return ;
116-     auto  extent_ptr = inner_for_->extent .as <IntImmNode>();
117-     if  (!extent_ptr)
116+     //  1. Compute raw element offset
117+     auto  strides = buffer->strides ;
118+     if  (buffer->strides .empty ()) {
119+       PrimExpr stride = 1 ;
120+       for  (int  i = indices.size () - 1 ; i >= 0 ; --i) {
121+         strides.push_back (stride);
122+         stride = stride * buffer->shape [i];
123+       }
124+       strides = Array<PrimExpr>{strides.rbegin (), strides.rend ()};
125+     }
126+     PrimExpr elem_offset = 0 ;
127+     for  (int  i = 0 ; i < indices.size (); ++i) {
128+       elem_offset += indices[i] * strides[i];
129+     }
130+ 
131+     //  2. If element offset is independent with loop_var, ignore it
132+     if  (CanProveIndependent (elem_offset, inner_for_->loop_var , &analyzer_)) {
118133      return ;
134+     }
119135
120-     const  DataType &access_type = buffer->dtype ;
121-     //  i // 2, i % 8 can also be vectorized as factor 16
122-     int  max_vector_size = vector_load_bits_max_ / access_type.bits ();
123-     //  so we should disable this GCD optimization
124-     max_vector_size = arith::ZeroAwareGCD (max_vector_size, extent_ptr->value );
125-     auto  last_dim = buffer->shape .back ();
126-     auto  mod_set = analyzer_.modular_set (last_dim);
127-     //  when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
128-     //  conditionally tail vectorize
129-     if  (buffer->shape .back ().as <IntImmNode>()) {
130-       max_vector_size = arith::ZeroAwareGCD (max_vector_size, mod_set->coeff );
131-       auto  gcd_base = arith::ZeroAwareGCD (max_vector_size, mod_set->base );
132-       //  If gcd_base is equal to the last dimension,
133-       //  we should analyze the second-to-last dimension
134-       //  in relation to the last dimension.
135-       if  (gcd_base < Downcast<IntImm>(last_dim)->value ) {
136-         max_vector_size = gcd_base;
137-       }
138-       vector_size_ = arith::ZeroAwareGCD (max_vector_size, vector_size_);
139- 
140-       //  Generate strides if not existed
141-       auto  strides = buffer->strides ;
142-       if  (buffer->strides .empty ()) {
143-         PrimExpr stride = 1 ;
144-         for  (int  i = indices.size () - 1 ; i >= 0 ; --i) {
145-           strides.push_back (stride);
146-           stride = stride * buffer->shape [i];
147-         }
148-         strides = Array<PrimExpr>{strides.rbegin (), strides.rend ()};
149-       }
136+     //  3. Tight vectorize bound
137+     vector_size_ = arith::ZeroAwareGCD (vector_size_, vector_load_bits_max_ /
138+                                                          buffer->dtype .bits ());
150139
151-       //  Generate and check element offset expression
152-       ICHECK (indices.size () == strides.size ()) << " Invalid indices and strides" 
153-       PrimExpr elem_offset = 0 ;
154-       for  (int  i = 0 ; i < indices.size (); ++i) {
155-         elem_offset += indices[i] * strides[i];
156-       }
157-       while  (!IndiceCanVectorize (elem_offset, inner_for_->loop_var ,
158-                                  inner_for_->extent , vector_size_,
159-                                  &analyzer_)) {
160-         vector_size_ /= 2 ;
161-       }
162-     } else  if  (vector_size_ <= vector_load_bits_max_ / buffer->dtype .bits ()) {
163-       //  dynamic shape load: get the vectorization condition
164-       dynamic_ = true ;
165-       PrimExpr offset = buffer.OffsetOf (indices).back ();
166-       condition_ = (FloorMod (offset, vector_size_) == 0 );
140+     //  4. Try to vectorize buffer load
141+     while  (!IndiceCanVectorize (elem_offset, inner_for_->loop_var ,
142+                                inner_for_->extent , vector_size_, &analyzer_)) {
143+       vector_size_ /= 2 ;
167144    }
168145  }
169146
170147  const  int  vector_load_bits_max_ = 128 ;
171148
172149  const  ForNode *inner_for_{};
173-   Map<Var, Range> iter_map_;
174150  bool  has_nonlocal_memory_access_ = false ;
175151  int  vector_size_ = 128 ;
176-   //  conditionally vectorize
177-   bool  dynamic_ = false ;
178-   PrimExpr condition_;
179152};
180153
181154class  VectorizeRewriter  : public  StmtExprMutator  {
182155public: 
183-   VectorizeRewriter (const  VectorizePlanResult &plan)
184-       : vector_size_(plan.vector_size), condition_(plan.condition),
185-         dynamic_ (plan.dynamic) {}
156+   VectorizeRewriter (int  vector_size) : vector_size_(vector_size) {}
186157
187158private: 
188159  Stmt VisitStmt_ (const  ForNode *node) final  {
@@ -197,23 +168,19 @@ class VectorizeRewriter : public StmtExprMutator {
197168      ICHECK (extent % vector_size_ == 0 )
198169          << " extent: " "  vector_size_: " 
199170      ICHECK (is_zero (fnode->min ));
200-       if  (!dynamic_) { //  check dynamic shape
201-         if  (extent == vector_size_) {
202-           fnode.CopyOnWrite ()->kind  = ForKind::kVectorized ;
203-           return  fnode;
204-         } else  {
205-           Var inner_var = Var (" vec" 
206-           Var outer_var = Var (old_var->name_hint );
207-           Map<Var, PrimExpr> vmap;
208-           vmap.Set (fnode->loop_var , outer_var * vector_size_ + inner_var);
209-           Stmt body = Substitute (fnode->body , vmap);
210-           body = For (inner_var, 0 , vector_size_, ForKind::kVectorized , body);
211-           body = For (outer_var, 0 , extent / vector_size_, fnode->kind , body,
212-                      fnode->thread_binding , fnode->annotations , fnode->span );
213-           return  body;
214-         }
215-       } else  {
171+       if  (extent == vector_size_) {
172+         fnode.CopyOnWrite ()->kind  = ForKind::kVectorized ;
216173        return  fnode;
174+       } else  {
175+         Var inner_var = Var (" vec" 
176+         Var outer_var = Var (old_var->name_hint );
177+         Map<Var, PrimExpr> vmap;
178+         vmap.Set (fnode->loop_var , outer_var * vector_size_ + inner_var);
179+         Stmt body = Substitute (fnode->body , vmap);
180+         body = For (inner_var, 0 , vector_size_, ForKind::kVectorized , body);
181+         body = For (outer_var, 0 , extent / vector_size_, fnode->kind , body,
182+                    fnode->thread_binding , fnode->annotations , fnode->span );
183+         return  body;
217184      }
218185    } else  {
219186      return  ret;
@@ -222,18 +189,25 @@ class VectorizeRewriter : public StmtExprMutator {
222189
223190  const  ForNode *inner_for_{};
224191  const  int  vector_size_;
225-   const  PrimExpr condition_;
226-   const  bool  dynamic_;
227192};
228193
229194int  GetVectorizeSize (const  For &loop) { return  VectorizePlanner ().Plan (loop); }
230195
231- VectorizePlanResult GetVectorizePlanResult (const  For &loop) {
232-   VectorizePlanner planner;
233-   int  vector_size = planner.Plan (loop);
234-   bool  dynamic = planner.GetDynamic ();
235-   PrimExpr condition = planner.GetCondition ();
236-   return  {vector_size, dynamic, condition};
196+ bool  CanProveIndependent (const  PrimExpr &expr, Var var,
197+                          arith::Analyzer *analyzer) {
198+   //  1. if var doesn't exist, it is independent
199+   bool  used_var = UsesVar (
200+       expr, [&](const  VarNode *v) { return  GetRef<Var>(v).same_as (var); });
201+   if  (!used_var) {
202+     return  true ;
203+   }
204+   //  2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
205+   Var var_1 (" _t" dtype ());
206+   auto  expr_1 = Substitute (expr, {{var, var_1}});
207+   if  (analyzer->CanProveEqual (expr, expr_1)) {
208+     return  true ;
209+   }
210+   return  false ;
237211}
238212
239213bool  IndiceCanVectorize (const  PrimExpr &expr, Var var,
@@ -280,14 +254,13 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
280254}
281255
282256For VectorizeLoop (const  For &loop, int  vectorize_hint) {
283-   VectorizePlanResult res{128 , false , 0 };
284257  if  (vectorize_hint <= 0 ) {
285-     res =  GetVectorizePlanResult (loop) ;
286-     vectorize_hint = res. vector_size ;
258+     VectorizePlanner planner ;
259+     vectorize_hint = planner. Plan (loop) ;
287260  }
288261  if  (vectorize_hint == 1 )
289262    return  loop;
290-   auto  rewriter = VectorizeRewriter (res );
263+   auto  rewriter = VectorizeRewriter (vectorize_hint );
291264  return  Downcast<For>(rewriter (loop));
292265}
293266
0 commit comments