@@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator {
3434 }
3535
// Mutator hook for Add nodes.
// This hunk is shown in unified-diff form: lines tagged '-' are the previous
// body (inline fma construction guarded by an early return), lines tagged '+'
// are the replacement that delegates to MakeFMA.
// After the change: if an operand of the Add is a Mul, the Mul's two factors
// and the other Add operand are passed to MakeFMA as (a, b, c); the right
// operand (op->b) is tested before the left one (op->a). If neither operand
// is a Mul, the node goes through the default IRMutator::Mutate_.
3636 Expr Mutate_ (const Add* op, const Expr& e) final {
37- if (fma_ == nullptr || !op->type .is_float ()) {
38- return IRMutator::Mutate_ (op, e);
39- }
// Case 1: op->a + (mb->a * mb->b)
4037 if (const Mul* mb = op->b .as <Mul>()) {
41- Expr r = (*fma_)(Call::make (
42- op->type , " fma" , {mb->a , mb->b , op->a }, Call::PureIntrinsic));
43- if (r.defined ()) return this ->Mutate (r);
38+ return MakeFMA (mb->a , mb->b , op->a , op, e);
// Case 2: (ma->a * ma->b) + op->b
4439 } else if (const Mul* ma = op->a .as <Mul>()) {
40+ return MakeFMA (ma->a , ma->b , op->b , op, e);
41+ }
// No multiply operand: nothing to fuse, use the base-class handling.
42+ return IRMutator::Mutate_ (op, e);
43+ }
44+
45+ private:
// Rewrites Broadcast(Cast(x)) into Cast(Broadcast(x)) when the cast exactly
// doubles the bit width of x; returns `e` itself in every other case.
// The rebuilt Broadcast keeps the original lane count (bcast->lanes) and the
// outer Cast targets the original broadcast's type (bcast->type).
46+ Expr SwapBroadcastCast (const Expr& e) {
47+ // Try to change broadcast(cast(x)) to cast(broadcast(x))
48+ // For some targets, LLVM will generate more efficient FMA
49+ // instruction with the latter. For example, vmla vs. vmlal
50+ // on ARM.
51+ if (const Broadcast* bcast = e.as <Broadcast>()) {
52+ if (const Cast* cast = bcast->value .as <Cast>()) {
// Only a cast whose result is exactly twice as wide as its input is swapped.
53+ if (cast->type .bits () == cast->value .type ().bits () * 2 ) {
54+ Expr new_bcast = Broadcast::make (cast->value , bcast->lanes );
55+ return Cast::make (bcast->type , new_bcast);
56+ }
57+ }
58+ }
// Pattern did not match: hand back the input expression unchanged.
59+ return e;
60+ }
61+
// Lowers the multiply-add a * b + c that was matched on the Add node `op`
// (whose handle is `e`).
// Both factors are first passed through SwapBroadcastCast. Then:
//  - if fma_ is set and op->type is float, an fma PureIntrinsic call is built
//    from (lhs, rhs, c) and given to *fma_; a defined result is re-mutated
//    and returned;
//  - otherwise, if SwapBroadcastCast changed either factor, the Add is
//    rebuilt as Mutate(lhs * rhs) + Mutate(c);
//  - in all remaining cases the original node goes through
//    IRMutator::Mutate_(op, e).
// The line tagged '-' below is the pre-change argument list kept by the diff.
62+ Expr MakeFMA (const Expr& a, const Expr& b, const Expr& c,
63+ const Add* op, const Expr& e) {
64+ // emit fma instruction: a * b + c
65+ Expr lhs = SwapBroadcastCast (a);
66+ Expr rhs = SwapBroadcastCast (b);
67+
68+ if (fma_ != nullptr && op->type .is_float ()) {
4569 Expr r = (*fma_)(Call::make (
46- op->type , " fma" , {ma-> a , ma-> b , op-> b }, Call::PureIntrinsic));
70+ op->type , " fma" , {lhs, rhs, c }, Call::PureIntrinsic));
// NOTE(review): when r is undefined, control falls through to the default
// mutation of the original `op`, so the swapped lhs/rhs are not used on
// that path - confirm this is intended.
4771 if (r.defined ()) return this ->Mutate (r);
72+ } else {
// same_as() tells whether SwapBroadcastCast returned its argument untouched;
// the Add is rebuilt only if at least one factor was actually rewritten.
73+ if (!lhs.same_as (a) || !rhs.same_as (b)) {
74+ Expr mul = this ->Mutate (Mul::make (lhs, rhs));
75+ return Add::make (mul, this ->Mutate (c));
76+ }
4877 }
4978 return IRMutator::Mutate_ (op, e);
5079 }
5180
52- private:
5381 Expr ApplyPattern (const std::string& name, const Expr& e) {
5482 for (size_t i = 0 ; i < patterns_.size (); ++i) {
5583 std::string& p = patterns_[i];
0 commit comments