Skip to content

Commit c899ab9

Browse files
committed
lint fix
1 parent 195a159 commit c899ab9

File tree

2 files changed

+95
-75
lines changed

2 files changed

+95
-75
lines changed

src/transform/lower_intrin.cc

Lines changed: 72 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ namespace tl {
4141
using namespace tir;
4242

4343
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
44-
public:
44+
public:
4545
using IRMutatorWithAnalyzer::VisitExpr_;
4646
using IRMutatorWithAnalyzer::VisitStmt_;
4747
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
4848

49-
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "")
49+
IntrinInjecter(arith::Analyzer *analyzer, std::string target,
50+
std::string mtriple = "")
5051
: IRMutatorWithAnalyzer(analyzer) {
5152
std::vector<std::string> patterns;
5253
patterns.push_back(target + ".FLowerIntrinsic");
@@ -59,7 +60,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
5960
patterns.push_back("default.FLowerIntrinsic");
6061
patterns.push_back("default.FLegalize");
6162

62-
for (const std::string& pattern : patterns)
63+
for (const std::string &pattern : patterns)
6364
if (Op::HasAttrMap(pattern)) {
6465
attr_maps_.push_back(Op::GetAttrMap<FLowerGeneral>(pattern));
6566
if (fma_ == nullptr) {
@@ -68,9 +69,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
6869
}
6970
}
7071

71-
PrimExpr VisitExpr_(const CallNode* op) final {
72-
if (auto* ptr_op = op->op.as<OpNode>()) {
73-
for (const auto& f_attr_map : attr_maps_) {
72+
PrimExpr VisitExpr_(const CallNode *op) final {
73+
if (auto *ptr_op = op->op.as<OpNode>()) {
74+
for (const auto &f_attr_map : attr_maps_) {
7475
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
7576
if (f != nullptr) {
7677
PrimExpr e = GetRef<PrimExpr>(op);
@@ -88,34 +89,35 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
8889
return IRMutatorWithAnalyzer::VisitExpr_(op);
8990
}
9091

91-
PrimExpr VisitExpr_(const AddNode* op) final {
92-
if (const MulNode* mb = op->b.as<MulNode>()) {
92+
PrimExpr VisitExpr_(const AddNode *op) final {
93+
if (const MulNode *mb = op->b.as<MulNode>()) {
9394
return MakeFMA(mb->a, mb->b, op->a, op);
94-
} else if (const MulNode* ma = op->a.as<MulNode>()) {
95+
} else if (const MulNode *ma = op->a.as<MulNode>()) {
9596
return MakeFMA(ma->a, ma->b, op->b, op);
9697
}
9798
return IRMutatorWithAnalyzer::VisitExpr_(op);
9899
}
99100

100101
// We use floordiv for integer analysis,
101102
// but will need to lower them to native truncdiv instructions
102-
PrimExpr VisitExpr_(const FloorDivNode* op) final {
103+
PrimExpr VisitExpr_(const FloorDivNode *op) final {
103104
auto e = GetRef<PrimExpr>(op);
104105
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
105106
op = ret.as<FloorDivNode>();
106-
if (op == nullptr) return ret;
107+
if (op == nullptr)
108+
return ret;
107109
int shift;
108-
const DataType& dtype = op->dtype;
110+
const DataType &dtype = op->dtype;
109111
ICHECK(dtype.is_int() || dtype.is_uint());
110112

111113
auto is_ceildiv_numerator = [&]() {
112114
std::vector<std::pair<PrimExpr, int>> terms;
113-
std::function<void(const PrimExpr&, int)> collect =
114-
[&](const PrimExpr& expr, int sign) {
115-
if (const auto* add = expr.as<AddNode>()) {
115+
std::function<void(const PrimExpr &, int)> collect =
116+
[&](const PrimExpr &expr, int sign) {
117+
if (const auto *add = expr.as<AddNode>()) {
116118
collect(add->a, sign);
117119
collect(add->b, sign);
118-
} else if (const auto* sub = expr.as<SubNode>()) {
120+
} else if (const auto *sub = expr.as<SubNode>()) {
119121
collect(sub->a, sign);
120122
collect(sub->b, -sign);
121123
} else {
@@ -125,10 +127,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
125127
collect(op->a, 1);
126128
int den_coeff = 0;
127129
int64_t const_term = 0;
128-
for (const auto& term : terms) {
129-
const PrimExpr& expr = term.first;
130+
for (const auto &term : terms) {
131+
const PrimExpr &expr = term.first;
130132
int sign = term.second;
131-
if (const auto* imm = expr.as<IntImmNode>()) {
133+
if (const auto *imm = expr.as<IntImmNode>()) {
132134
const_term += static_cast<int64_t>(sign) * imm->value;
133135
} else if (analyzer_->CanProveEqual(expr, op->b)) {
134136
den_coeff += sign;
@@ -146,7 +148,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
146148

147149
if (analyzer_->CanProveGreaterEqual(op->b, 0)) {
148150
// Common path, positive divisor
149-
if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
151+
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
152+
analyzer_->CanProveGreaterEqual(e, 0)) {
150153
return truncdiv(op->a, op->b);
151154
}
152155

@@ -155,9 +158,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
155158
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
156159
if (const_int_bound->min_value < 0 &&
157160
const_int_bound->min_value >
158-
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
159-
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
160-
// negative operands.
161+
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))
162+
->value)) {
163+
// The goal is to write floordiv(a,b) in terms of truncdiv, without
164+
// using negative operands.
161165
//
162166
// For any integer c
163167
//
@@ -188,7 +192,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
188192
// == truncdiv(a + b*c, b) - c
189193
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
190194
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
191-
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
195+
PrimExpr offset_numerator =
196+
analyzer_->Simplify(op->a + op->b * ceildiv);
192197
return truncdiv(offset_numerator, op->b) - ceildiv;
193198
}
194199

@@ -198,7 +203,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
198203
// condition on b >= 0.
199204
// truncmod(a, b) < 0 will implies ceildiv,
200205
// So we need to correct these cases.
201-
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
206+
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) &&
207+
support_bitwise_op_) {
202208
// equivalent to rdiv + (rmod >= 0 ? 0: -1);
203209
return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1));
204210
} else {
@@ -216,27 +222,29 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
216222
auto rdiv = tir::Var("rdiv", dtype);
217223
// b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
218224
// b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
219-
PrimExpr let_rdiv =
220-
tir::Let(rdiv, truncdiv(op->a, op->b),
221-
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
222-
rdiv - make_const(dtype, 1)));
225+
PrimExpr let_rdiv = tir::Let(
226+
rdiv, truncdiv(op->a, op->b),
227+
tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
228+
rdiv, rdiv - make_const(dtype, 1)));
223229
return Let(rmod, truncmod(op->a, op->b), let_rdiv);
224230
}
225231
}
226232
}
227233

228-
PrimExpr VisitExpr_(const FloorModNode* op) final {
234+
PrimExpr VisitExpr_(const FloorModNode *op) final {
229235
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
230236
op = ret.as<FloorModNode>();
231-
if (op == nullptr) return ret;
237+
if (op == nullptr)
238+
return ret;
232239
// Lower floordiv to native truncdiv.
233240
int shift;
234-
const DataType& dtype = op->dtype;
241+
const DataType &dtype = op->dtype;
235242
ICHECK(dtype.is_int() || dtype.is_uint());
236243

237244
if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) {
238245
// lower to masking if possible.
239-
int64_t mask = (static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
246+
int64_t mask =
247+
(static_cast<int64_t>(1) << static_cast<int64_t>(shift)) - 1;
240248
return op->a & make_const(dtype, mask);
241249
}
242250

@@ -251,7 +259,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
251259
arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a);
252260
if (const_int_bound->min_value < 0 &&
253261
const_int_bound->min_value >
254-
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value)) {
262+
-(Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))
263+
->value)) {
255264
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
256265
// without using negative operands.
257266
//
@@ -283,7 +292,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
283292
// == truncmod(a + b*c, b)
284293
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
285294
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
286-
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
295+
PrimExpr offset_numerator =
296+
analyzer_->Simplify(op->a + op->b * ceildiv);
287297
return truncmod(offset_numerator, op->b);
288298
}
289299

@@ -292,7 +302,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
292302
// mod(a, b) < 0 will imply we are doing ceildiv,
293303
// So we need to correct these cases.
294304
PrimExpr rmod = truncmod(op->a, op->b);
295-
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) && support_bitwise_op_) {
305+
if ((dtype == DataType::Int(32) || dtype == DataType::Int(64)) &&
306+
support_bitwise_op_) {
296307
// (rmod >> shift) & b
297308
// -> (rmod >= 0 ? 0: -1) & b
298309
// -> rmod >= 0 ? 0 : b
@@ -304,23 +315,25 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
304315
} else {
305316
if (dtype.is_float()) {
306317
// a - floor(a / b) * b
307-
return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
318+
return op->a -
319+
(VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
308320
} else {
309321
// uncommon case
310-
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
322+
DLOG(INFO)
323+
<< "LowerFloorMod: Cannot decide the sign of divsor and divident";
311324
auto rmod = tir::Var("rmod", dtype);
312325
// b > 0 && rmod >= 0 -> rmod
313326
// b > 0 && rmod < 0 -> rmod + b
314327
// b < 0 && rmod < 0 -> rmod
315328
// b < 0 && rmod > 0 -> rmod + b
316-
return Let(
317-
rmod, truncmod(op->a, op->b),
318-
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
329+
return Let(rmod, truncmod(op->a, op->b),
330+
Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0),
331+
rmod, rmod + op->b));
319332
}
320333
}
321334
}
322335

323-
PrimExpr VisitExpr_(const MaxNode* op) final {
336+
PrimExpr VisitExpr_(const MaxNode *op) final {
324337
using namespace arith;
325338
PVar<PrimExpr> x, y;
326339
PVar<IntImm> c;
@@ -332,7 +345,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
332345
return IRMutatorWithAnalyzer::VisitExpr_(op);
333346
}
334347

335-
PrimExpr VisitExpr_(const EQNode* op) final {
348+
PrimExpr VisitExpr_(const EQNode *op) final {
336349
using namespace arith;
337350
PVar<PrimExpr> x, y;
338351
auto e = GetRef<PrimExpr>(op);
@@ -342,7 +355,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
342355
return IRMutatorWithAnalyzer::VisitExpr_(op);
343356
}
344357

345-
PrimExpr VisitExpr_(const NENode* op) final {
358+
PrimExpr VisitExpr_(const NENode *op) final {
346359
using namespace arith;
347360
PVar<PrimExpr> x, y;
348361
auto e = GetRef<PrimExpr>(op);
@@ -352,14 +365,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
352365
return IRMutatorWithAnalyzer::VisitExpr_(op);
353366
}
354367

355-
private:
356-
PrimExpr SwapBroadcastCast(const PrimExpr& e) {
368+
private:
369+
PrimExpr SwapBroadcastCast(const PrimExpr &e) {
357370
// Try to change broadcast(cast(x)) to cast(broadcast(x))
358371
// For some targets, LLVM will generate more efficient FMA
359372
// instruction with the latter. For example, vmla vs. vmlal
360373
// on ARM.
361-
if (const BroadcastNode* bcast = e.as<BroadcastNode>()) {
362-
if (const CastNode* cast = bcast->value.as<CastNode>()) {
374+
if (const BroadcastNode *bcast = e.as<BroadcastNode>()) {
375+
if (const CastNode *cast = bcast->value.as<CastNode>()) {
363376
auto should_swap = [&]() {
364377
// Maintain behaviour (int8 -> int16, fp16 -> fp32).
365378
if (cast->dtype.bits() == cast->value.dtype().bits() * 2) {
@@ -385,14 +398,16 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
385398
return e;
386399
}
387400

388-
PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
401+
PrimExpr MakeFMA(const PrimExpr &a, const PrimExpr &b, const PrimExpr &c,
402+
const AddNode *op) {
389403
// emit fma instruction: a * b + c
390404
PrimExpr lhs = SwapBroadcastCast(a);
391405
PrimExpr rhs = SwapBroadcastCast(b);
392406

393407
if (fma_ != nullptr && op->dtype.is_float()) {
394408
PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
395-
if (r.defined()) return this->VisitExpr(r);
409+
if (r.defined())
410+
return this->VisitExpr(r);
396411
} else {
397412
if (!lhs.same_as(a) || !rhs.same_as(b)) {
398413
PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
@@ -408,7 +423,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
408423
bool support_bitwise_op_{true};
409424
};
410425

411-
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
426+
Stmt LowerIntrinStmt(Stmt stmt, const std::string &target) {
412427
arith::Analyzer analyzer;
413428
return IntrinInjecter(&analyzer, target)(std::move(stmt));
414429
}
@@ -418,13 +433,13 @@ using namespace tir::transform;
418433

419434
tvm::transform::Pass LowerIntrin() {
420435
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
421-
auto* n = f.CopyOnWrite();
436+
auto *n = f.CopyOnWrite();
422437
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
423438
ICHECK(target.defined()) << "LowerIntrin: Require the target attribute";
424439
arith::Analyzer analyzer;
425440
auto mtriple = target.value()->GetAttr<String>("mtriple", "");
426-
n->body =
427-
IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body));
441+
n->body = IntrinInjecter(&analyzer, target.value()->kind->name,
442+
mtriple.value())(std::move(n->body));
428443
return f;
429444
};
430445
return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
@@ -435,7 +450,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
435450
refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
436451
});
437452

438-
} // namespace transform
453+
} // namespace transform
439454

440-
} // namespace tl
441-
} // namespace tvm
455+
} // namespace tl
456+
} // namespace tvm

0 commit comments

Comments
 (0)