Skip to content

Commit 592a1f6

Browse files
ZihengJiangtqchen
authored andcommitted
[CODEGEN] Detect broadcast(cast(x)) pattern in FMA (#551)
* [CODEGEN] Detect broadcast(cast(x)) pattern in FMA * [CODEGEN] Improve * [CODEGEN] Fix
1 parent fde9b57 commit 592a1f6

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/pass/lower_intrin.cc

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,50 @@ class IntrinInjecter : public IRMutator {
3434
}
3535

3636
Expr Mutate_(const Add* op, const Expr& e) final {
37-
if (fma_ == nullptr || !op->type.is_float()) {
38-
return IRMutator::Mutate_(op, e);
39-
}
4037
if (const Mul* mb = op->b.as<Mul>()) {
41-
Expr r = (*fma_)(Call::make(
42-
op->type, "fma", {mb->a, mb->b, op->a}, Call::PureIntrinsic));
43-
if (r.defined()) return this->Mutate(r);
38+
return MakeFMA(mb->a, mb->b, op->a, op, e);
4439
} else if (const Mul* ma = op->a.as<Mul>()) {
40+
return MakeFMA(ma->a, ma->b, op->b, op, e);
41+
}
42+
return IRMutator::Mutate_(op, e);
43+
}
44+
45+
private:
46+
Expr SwapBroadcastCast(const Expr& e) {
47+
// Try to change broadcast(cast(x)) to cast(broadcast(x))
48+
// For some targets, LLVM will generate more efficient FMA
49+
// instruction with the latter. For example, vmla vs. vmlal
50+
// on ARM.
51+
if (const Broadcast* bcast = e.as<Broadcast>()) {
52+
if (const Cast* cast = bcast->value.as<Cast>()) {
53+
if (cast->type.bits() == cast->value.type().bits() * 2) {
54+
Expr new_bcast = Broadcast::make(cast->value, bcast->lanes);
55+
return Cast::make(bcast->type, new_bcast);
56+
}
57+
}
58+
}
59+
return e;
60+
}
61+
62+
Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c,
63+
const Add* op, const Expr& e) {
64+
// emit fma instruction: a * b + c
65+
Expr lhs = SwapBroadcastCast(a);
66+
Expr rhs = SwapBroadcastCast(b);
67+
68+
if (fma_ != nullptr && op->type.is_float()) {
4569
Expr r = (*fma_)(Call::make(
46-
op->type, "fma", {ma->a, ma->b, op->b}, Call::PureIntrinsic));
70+
op->type, "fma", {lhs, rhs, c}, Call::PureIntrinsic));
4771
if (r.defined()) return this->Mutate(r);
72+
} else {
73+
if (!lhs.same_as(a) || !rhs.same_as(b)) {
74+
Expr mul = this->Mutate(Mul::make(lhs, rhs));
75+
return Add::make(mul, this->Mutate(c));
76+
}
4877
}
4978
return IRMutator::Mutate_(op, e);
5079
}
5180

52-
private:
5381
Expr ApplyPattern(const std::string& name, const Expr& e) {
5482
for (size_t i = 0; i < patterns_.size(); ++i) {
5583
std::string& p = patterns_[i];

0 commit comments

Comments
 (0)