@@ -41,12 +41,13 @@ namespace tl {
4141using namespace tir ;
4242
4343class IntrinInjecter : public tvm ::arith::IRMutatorWithAnalyzer {
44- public:
44+ public:
4545 using IRMutatorWithAnalyzer::VisitExpr_;
4646 using IRMutatorWithAnalyzer::VisitStmt_;
4747 using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
4848
49- IntrinInjecter (arith::Analyzer* analyzer, std::string target, std::string mtriple = " " )
49+ IntrinInjecter (arith::Analyzer *analyzer, std::string target,
50+ std::string mtriple = " " )
5051 : IRMutatorWithAnalyzer(analyzer) {
5152 std::vector<std::string> patterns;
5253 patterns.push_back (target + " .FLowerIntrinsic" );
@@ -59,7 +60,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
5960 patterns.push_back (" default.FLowerIntrinsic" );
6061 patterns.push_back (" default.FLegalize" );
6162
62- for (const std::string& pattern : patterns)
63+ for (const std::string & pattern : patterns)
6364 if (Op::HasAttrMap (pattern)) {
6465 attr_maps_.push_back (Op::GetAttrMap<FLowerGeneral>(pattern));
6566 if (fma_ == nullptr ) {
@@ -68,9 +69,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
6869 }
6970 }
7071
71- PrimExpr VisitExpr_ (const CallNode* op) final {
72- if (auto * ptr_op = op->op .as <OpNode>()) {
73- for (const auto & f_attr_map : attr_maps_) {
72+ PrimExpr VisitExpr_ (const CallNode * op) final {
73+ if (auto * ptr_op = op->op .as <OpNode>()) {
74+ for (const auto & f_attr_map : attr_maps_) {
7475 FLowerGeneral f = f_attr_map.get (GetRef<Op>(ptr_op), nullptr );
7576 if (f != nullptr ) {
7677 PrimExpr e = GetRef<PrimExpr>(op);
@@ -88,34 +89,35 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
8889 return IRMutatorWithAnalyzer::VisitExpr_ (op);
8990 }
9091
91- PrimExpr VisitExpr_ (const AddNode* op) final {
92- if (const MulNode* mb = op->b .as <MulNode>()) {
92+ PrimExpr VisitExpr_ (const AddNode * op) final {
93+ if (const MulNode * mb = op->b .as <MulNode>()) {
9394 return MakeFMA (mb->a , mb->b , op->a , op);
94- } else if (const MulNode* ma = op->a .as <MulNode>()) {
95+ } else if (const MulNode * ma = op->a .as <MulNode>()) {
9596 return MakeFMA (ma->a , ma->b , op->b , op);
9697 }
9798 return IRMutatorWithAnalyzer::VisitExpr_ (op);
9899 }
99100
100101 // We use floordiv for integer analysis,
101102 // but will need to lower them to native truncdiv instructions
102- PrimExpr VisitExpr_ (const FloorDivNode* op) final {
103+ PrimExpr VisitExpr_ (const FloorDivNode * op) final {
103104 auto e = GetRef<PrimExpr>(op);
104105 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_ (op);
105106 op = ret.as <FloorDivNode>();
106- if (op == nullptr ) return ret;
107+ if (op == nullptr )
108+ return ret;
107109 int shift;
108- const DataType& dtype = op->dtype ;
110+ const DataType & dtype = op->dtype ;
109111 ICHECK (dtype.is_int () || dtype.is_uint ());
110112
111113 auto is_ceildiv_numerator = [&]() {
112114 std::vector<std::pair<PrimExpr, int >> terms;
113- std::function<void (const PrimExpr&, int )> collect =
114- [&](const PrimExpr& expr, int sign) {
115- if (const auto * add = expr.as <AddNode>()) {
115+ std::function<void (const PrimExpr &, int )> collect =
116+ [&](const PrimExpr & expr, int sign) {
117+ if (const auto * add = expr.as <AddNode>()) {
116118 collect (add->a , sign);
117119 collect (add->b , sign);
118- } else if (const auto * sub = expr.as <SubNode>()) {
120+ } else if (const auto * sub = expr.as <SubNode>()) {
119121 collect (sub->a , sign);
120122 collect (sub->b , -sign);
121123 } else {
@@ -125,10 +127,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
125127 collect (op->a , 1 );
126128 int den_coeff = 0 ;
127129 int64_t const_term = 0 ;
128- for (const auto & term : terms) {
129- const PrimExpr& expr = term.first ;
130+ for (const auto & term : terms) {
131+ const PrimExpr & expr = term.first ;
130132 int sign = term.second ;
131- if (const auto * imm = expr.as <IntImmNode>()) {
133+ if (const auto * imm = expr.as <IntImmNode>()) {
132134 const_term += static_cast <int64_t >(sign) * imm->value ;
133135 } else if (analyzer_->CanProveEqual (expr, op->b )) {
134136 den_coeff += sign;
@@ -146,7 +148,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
146148
147149 if (analyzer_->CanProveGreaterEqual (op->b , 0 )) {
148150 // Common path, positive divisor
149- if (analyzer_->CanProveGreaterEqual (op->a , 0 ) || analyzer_->CanProveGreaterEqual (e, 0 )) {
151+ if (analyzer_->CanProveGreaterEqual (op->a , 0 ) ||
152+ analyzer_->CanProveGreaterEqual (e, 0 )) {
150153 return truncdiv (op->a , op->b );
151154 }
152155
@@ -155,9 +158,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
155158 arith::ConstIntBound const_int_bound = analyzer_->const_int_bound (op->a );
156159 if (const_int_bound->min_value < 0 &&
157160 const_int_bound->min_value >
158- -(Downcast<IntImm>(tvm::max_value (op->a ->dtype .element_of ()))->value )) {
159- // The goal is to write floordiv(a,b) in terms of truncdiv, without using
160- // negative operands.
161+ -(Downcast<IntImm>(tvm::max_value (op->a ->dtype .element_of ()))
162+ ->value )) {
163+ // The goal is to write floordiv(a,b) in terms of truncdiv, without
164+ // using negative operands.
161165 //
162166 // For any integer c
163167 //
@@ -188,7 +192,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
188192 // == truncdiv(a + b*c, b) - c
189193 IntImm min (op->a ->dtype .element_of (), const_int_bound->min_value );
190194 PrimExpr ceildiv = truncdiv ((op->b - 1 ) - min, op->b );
191- PrimExpr offset_numerator = analyzer_->Simplify (op->a + op->b * ceildiv);
195+ PrimExpr offset_numerator =
196+ analyzer_->Simplify (op->a + op->b * ceildiv);
192197 return truncdiv (offset_numerator, op->b ) - ceildiv;
193198 }
194199
@@ -198,7 +203,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
198203 // condition on b >= 0.
199204 // truncmod(a, b) < 0 will implies ceildiv,
200205 // So we need to correct these cases.
201- if ((dtype == DataType::Int (32 ) || dtype == DataType::Int (64 )) && support_bitwise_op_) {
206+ if ((dtype == DataType::Int (32 ) || dtype == DataType::Int (64 )) &&
207+ support_bitwise_op_) {
202208 // equivalent to rdiv + (rmod >= 0 ? 0: -1);
203209 return rdiv + (rmod >> make_const (dtype, dtype.bits () - 1 ));
204210 } else {
@@ -216,27 +222,29 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
216222 auto rdiv = tir::Var (" rdiv" , dtype);
217223 // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
218224 // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
219- PrimExpr let_rdiv =
220- tir::Let ( rdiv, truncdiv (op->a , op->b ),
221- tir::Select ((op->b >= 0 && rmod >= 0 ) || (op->b < 0 && rmod <= 0 ), rdiv ,
222- rdiv - make_const (dtype, 1 )));
225+ PrimExpr let_rdiv = tir::Let (
226+ rdiv, truncdiv (op->a , op->b ),
227+ tir::Select ((op->b >= 0 && rmod >= 0 ) || (op->b < 0 && rmod <= 0 ),
228+ rdiv, rdiv - make_const (dtype, 1 )));
223229 return Let (rmod, truncmod (op->a , op->b ), let_rdiv);
224230 }
225231 }
226232 }
227233
228- PrimExpr VisitExpr_ (const FloorModNode* op) final {
234+ PrimExpr VisitExpr_ (const FloorModNode * op) final {
229235 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_ (op);
230236 op = ret.as <FloorModNode>();
231- if (op == nullptr ) return ret;
237+ if (op == nullptr )
238+ return ret;
232239 // Lower floordiv to native truncdiv.
233240 int shift;
234- const DataType& dtype = op->dtype ;
241+ const DataType & dtype = op->dtype ;
235242 ICHECK (dtype.is_int () || dtype.is_uint ());
236243
237244 if (support_bitwise_op_ && is_const_power_of_two_integer (op->b , &shift)) {
238245 // lower to masking if possible.
239- int64_t mask = (static_cast <int64_t >(1 ) << static_cast <int64_t >(shift)) - 1 ;
246+ int64_t mask =
247+ (static_cast <int64_t >(1 ) << static_cast <int64_t >(shift)) - 1 ;
240248 return op->a & make_const (dtype, mask);
241249 }
242250
@@ -251,7 +259,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
251259 arith::ConstIntBound const_int_bound = analyzer_->const_int_bound (op->a );
252260 if (const_int_bound->min_value < 0 &&
253261 const_int_bound->min_value >
254- -(Downcast<IntImm>(tvm::max_value (op->a ->dtype .element_of ()))->value )) {
262+ -(Downcast<IntImm>(tvm::max_value (op->a ->dtype .element_of ()))
263+ ->value )) {
255264 // The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
256265 // without using negative operands.
257266 //
@@ -283,7 +292,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
283292 // == truncmod(a + b*c, b)
284293 IntImm min (op->a ->dtype .element_of (), const_int_bound->min_value );
285294 PrimExpr ceildiv = truncdiv (-min + (op->b - 1 ), op->b );
286- PrimExpr offset_numerator = analyzer_->Simplify (op->a + op->b * ceildiv);
295+ PrimExpr offset_numerator =
296+ analyzer_->Simplify (op->a + op->b * ceildiv);
287297 return truncmod (offset_numerator, op->b );
288298 }
289299
@@ -292,7 +302,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
292302 // mod(a, b) < 0 will imply we are doing ceildiv,
293303 // So we need to correct these cases.
294304 PrimExpr rmod = truncmod (op->a , op->b );
295- if ((dtype == DataType::Int (32 ) || dtype == DataType::Int (64 )) && support_bitwise_op_) {
305+ if ((dtype == DataType::Int (32 ) || dtype == DataType::Int (64 )) &&
306+ support_bitwise_op_) {
296307 // (rmod >> shift) & b
297308 // -> (rmod >= 0 ? 0: -1) & b
298309 // -> rmod >= 0 ? 0 : b
@@ -304,23 +315,25 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
304315 } else {
305316 if (dtype.is_float ()) {
306317 // a - floor(a / b) * b
307- return op->a - (VisitExpr_ (tvm::floor (op->a / op->b ).as <CallNode>()) * op->b );
318+ return op->a -
319+ (VisitExpr_ (tvm::floor (op->a / op->b ).as <CallNode>()) * op->b );
308320 } else {
309321 // uncommon case
310- DLOG (INFO) << " LowerFloorMod: Cannot decide the sign of divsor and divident" ;
322+ DLOG (INFO)
323+ << " LowerFloorMod: Cannot decide the sign of divsor and divident" ;
311324 auto rmod = tir::Var (" rmod" , dtype);
312325 // b > 0 && rmod >= 0 -> rmod
313326 // b > 0 && rmod < 0 -> rmod + b
314327 // b < 0 && rmod < 0 -> rmod
315328 // b < 0 && rmod > 0 -> rmod + b
316- return Let (
317- rmod, truncmod ( op->a , op->b ),
318- Select ((op-> b >= 0 && rmod >= 0 ) || (op-> b < 0 && rmod <= 0 ), rmod, rmod + op->b ));
329+ return Let (rmod, truncmod (op-> a , op-> b ),
330+ Select (( op->b >= 0 && rmod >= 0 ) || ( op->b < 0 && rmod <= 0 ),
331+ rmod, rmod + op->b ));
319332 }
320333 }
321334 }
322335
323- PrimExpr VisitExpr_ (const MaxNode* op) final {
336+ PrimExpr VisitExpr_ (const MaxNode * op) final {
324337 using namespace arith ;
325338 PVar<PrimExpr> x, y;
326339 PVar<IntImm> c;
@@ -332,7 +345,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
332345 return IRMutatorWithAnalyzer::VisitExpr_ (op);
333346 }
334347
335- PrimExpr VisitExpr_ (const EQNode* op) final {
348+ PrimExpr VisitExpr_ (const EQNode * op) final {
336349 using namespace arith ;
337350 PVar<PrimExpr> x, y;
338351 auto e = GetRef<PrimExpr>(op);
@@ -342,7 +355,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
342355 return IRMutatorWithAnalyzer::VisitExpr_ (op);
343356 }
344357
345- PrimExpr VisitExpr_ (const NENode* op) final {
358+ PrimExpr VisitExpr_ (const NENode * op) final {
346359 using namespace arith ;
347360 PVar<PrimExpr> x, y;
348361 auto e = GetRef<PrimExpr>(op);
@@ -352,14 +365,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
352365 return IRMutatorWithAnalyzer::VisitExpr_ (op);
353366 }
354367
355- private:
356- PrimExpr SwapBroadcastCast (const PrimExpr& e) {
368+ private:
369+ PrimExpr SwapBroadcastCast (const PrimExpr & e) {
357370 // Try to change broadcast(cast(x)) to cast(broadcast(x))
358371 // For some targets, LLVM will generate more efficient FMA
359372 // instruction with the latter. For example, vmla vs. vmlal
360373 // on ARM.
361- if (const BroadcastNode* bcast = e.as <BroadcastNode>()) {
362- if (const CastNode* cast = bcast->value .as <CastNode>()) {
374+ if (const BroadcastNode * bcast = e.as <BroadcastNode>()) {
375+ if (const CastNode * cast = bcast->value .as <CastNode>()) {
363376 auto should_swap = [&]() {
364377 // Maintain behaviour (int8 -> int16, fp16 -> fp32).
365378 if (cast->dtype .bits () == cast->value .dtype ().bits () * 2 ) {
@@ -385,14 +398,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
385398 return e;
386399 }
387400
388- PrimExpr MakeFMA (const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
401+ PrimExpr MakeFMA (const PrimExpr &a, const PrimExpr &b, const PrimExpr &c,
402+ const AddNode *op) {
389403 // emit fma instruction: a * b + c
390404 PrimExpr lhs = SwapBroadcastCast (a);
391405 PrimExpr rhs = SwapBroadcastCast (b);
392406
393407 if (fma_ != nullptr && op->dtype .is_float ()) {
394408 PrimExpr r = fma_ (Call (op->dtype , builtin::fma (), {lhs, rhs, c}));
395- if (r.defined ()) return this ->VisitExpr (r);
409+ if (r.defined ())
410+ return this ->VisitExpr (r);
396411 } else {
397412 if (!lhs.same_as (a) || !rhs.same_as (b)) {
398413 PrimExpr mul = this ->VisitExpr (Mul (lhs, rhs));
@@ -408,7 +423,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
408423 bool support_bitwise_op_{true };
409424};
410425
411- Stmt LowerIntrinStmt (Stmt stmt, const std::string& target) {
426+ Stmt LowerIntrinStmt (Stmt stmt, const std::string & target) {
412427 arith::Analyzer analyzer;
413428 return IntrinInjecter (&analyzer, target)(std::move (stmt));
414429}
@@ -418,13 +433,13 @@ using namespace tir::transform;
418433
419434tvm::transform::Pass LowerIntrin () {
420435 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
421- auto * n = f.CopyOnWrite ();
436+ auto * n = f.CopyOnWrite ();
422437 auto target = f->GetAttr <Target>(tvm::attr::kTarget );
423438 ICHECK (target.defined ()) << " LowerIntrin: Require the target attribute" ;
424439 arith::Analyzer analyzer;
425440 auto mtriple = target.value ()->GetAttr <String>(" mtriple" , " " );
426- n->body =
427- IntrinInjecter (&analyzer, target. value ()-> kind -> name , mtriple.value ())(std::move (n->body ));
441+ n->body = IntrinInjecter (&analyzer, target. value ()-> kind -> name ,
442+ mtriple.value ())(std::move (n->body ));
428443 return f;
429444 };
430445 return CreatePrimFuncPass (pass_func, 0 , " tl.LowerIntrin" , {});
@@ -435,7 +450,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
435450 refl::GlobalDef ().def (" tl.transform.LowerIntrin" , LowerIntrin);
436451});
437452
438- } // namespace transform
453+ } // namespace transform
439454
440- } // namespace tl
441- } // namespace tvm
455+ } // namespace tl
456+ } // namespace tvm
0 commit comments