A few improvements in fcmla pattern recognition #173818
Draft
A few improvements in fcmla pattern recognition #173818
Conversation
* Relax the requirement on exact fast-math flag matching: it should be enough to require that all flags include reassoc. * Fall back to treating non-reassoc additions as addends to discover more deinterleaving opportunities.
You can test this locally with the following command: git-clang-format --diff origin/main HEAD --extensions cpp -- llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --diff_from_common_commit
View the diff from clang-format here.diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index f86459f80..a9e4bab0a 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -282,7 +282,10 @@ public:
CompositeNode *CommonNode{nullptr};
ComplexDeinterleavingRotation Rotation;
bool AllowContract;
- bool IsCommonReal() const { return Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180; }
+ bool IsCommonReal() const {
+ return Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
+ Rotation == ComplexDeinterleavingRotation::Rotation_180;
+ }
};
explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
@@ -384,8 +387,8 @@ private:
}
CompositeNode *negCompositeNode(CompositeNode *Node) {
- auto NegNode = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
- nullptr, nullptr);
+ auto NegNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
NegNode->Opcode = Instruction::FNeg;
NegNode->addOperand(Node);
return submitCompositeNode(NegNode);
@@ -403,8 +406,9 @@ private:
/// 270: r: cr + ai * bi
/// i: ci - ai * br
CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag,
- bool RealPositive=true, bool ImagPositive=true,
- PartialMulNode *PN=nullptr);
+ bool RealPositive = true,
+ bool ImagPositive = true,
+ PartialMulNode *PN = nullptr);
/// Identifies a complex add pattern and its rotation, based on the following
/// patterns.
@@ -436,9 +440,9 @@ private:
CompositeNode *Accumulator,
bool &AccumPositive);
- /// Extract one addend that have both real and imaginary parts positive/negative.
- CompositeNode *extractAddend(AddendList &RealAddends,
- AddendList &ImagAddends,
+ /// Extract one addend that have both real and imaginary parts
+ /// positive/negative.
+ CompositeNode *extractAddend(AddendList &RealAddends, AddendList &ImagAddends,
bool Positive);
/// Determine if sum of multiplications of complex numbers can be formed from
@@ -647,8 +651,7 @@ static const IntrinsicInst *getFMAOrMulAdd(const Instruction *I) {
}
static inline ComplexDeinterleavingRotation
-flipRotation(ComplexDeinterleavingRotation Rotation, bool Cond=true)
-{
+flipRotation(ComplexDeinterleavingRotation Rotation, bool Cond = true) {
if (!Cond)
return Rotation;
return ComplexDeinterleavingRotation(unsigned(Rotation) ^ 2);
@@ -676,9 +679,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
bool RealPositive,
bool ImagPositive,
PartialMulNode *PN) {
- LLVM_DEBUG(dbgs() << "identifyPartialMul "
- << (RealPositive ? " + " : " - ") << *Real << " / "
- << (ImagPositive ? " + " : " - ") << *Imag << "\n");
+ LLVM_DEBUG(dbgs() << "identifyPartialMul " << (RealPositive ? " + " : " - ")
+ << *Real << " / " << (ImagPositive ? " + " : " - ") << *Imag
+ << "\n");
bool AllowContract = true;
@@ -702,14 +705,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
};
auto ProcessMulAdd = [&](Product Mul, Addend Add, bool CheckAdd,
- SmallVectorImpl<Product> &Muls, Addend &Addend) {
+ SmallVectorImpl<Product> &Muls, Addend &Addend) {
Muls.push_back(Mul);
if (CheckAdd) {
if (auto AddI = dyn_cast<Instruction>(Add.first)) {
auto Op = AddI->getOpcode();
if (Op == Instruction::FMul || Op == Instruction::Mul) {
- Muls.emplace_back(GetProduct(AddI->getOperand(0), AddI->getOperand(1),
- Add.second));
+ Muls.emplace_back(
+ GetProduct(AddI->getOperand(0), AddI->getOperand(1), Add.second));
return;
}
}
@@ -727,18 +730,18 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
IsPositive = !IsPositive;
}
if (auto II = getFMAOrMulAdd(I)) {
- ProcessMulAdd(GetProduct(II->getArgOperand(0), II->getArgOperand(1),
- IsPositive),
- GetAddend(II->getArgOperand(2), IsPositive),
- II->getFastMathFlags().allowReassoc(), Muls, Addend);
+ ProcessMulAdd(
+ GetProduct(II->getArgOperand(0), II->getArgOperand(1), IsPositive),
+ GetAddend(II->getArgOperand(2), IsPositive),
+ II->getFastMathFlags().allowReassoc(), Muls, Addend);
return true;
}
unsigned Opcode = I->getOpcode();
if (I->hasOneUse() &&
(Opcode == Instruction::FMul || Opcode == Instruction::Mul)) {
- Muls.push_back(GetProduct(I->getOperand(0), I->getOperand(1),
- IsPositive));
+ Muls.push_back(
+ GetProduct(I->getOperand(0), I->getOperand(1), IsPositive));
return true;
}
@@ -759,10 +762,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
unsigned Opcode0 = I0->getOpcode();
if (I0->hasOneUse() &&
(Opcode0 == Instruction::FMul || Opcode0 == Instruction::Mul)) {
- ProcessMulAdd(GetProduct(I0->getOperand(0), I0->getOperand(1),
- IsPositive),
- GetAddend(Op1, IsPositive ^ IsSub),
- true, Muls, Addend);
+ ProcessMulAdd(
+ GetProduct(I0->getOperand(0), I0->getOperand(1), IsPositive),
+ GetAddend(Op1, IsPositive ^ IsSub), true, Muls, Addend);
return true;
}
}
@@ -772,16 +774,15 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
(Opcode1 == Instruction::FMul || Opcode1 == Instruction::Mul)) {
ProcessMulAdd(GetProduct(I1->getOperand(0), I1->getOperand(1),
IsPositive ^ IsSub),
- GetAddend(Op0, IsPositive),
- false, Muls, Addend);
+ GetAddend(Op0, IsPositive), false, Muls, Addend);
return true;
}
}
return false;
};
- auto MatchCommons = [&](PartialMulNode *PN,
- CompositeNode *CN, bool CNPositive) -> CompositeNode* {
+ auto MatchCommons = [&](PartialMulNode *PN, CompositeNode *CN,
+ bool CNPositive) -> CompositeNode * {
assert(PN);
for (auto PN0 = PN; PN0; PN0 = PN0->prev) {
if (PN0->CommonNode)
@@ -794,8 +795,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
auto Common1 = PN1->Common;
if (RealCommon0 == PN1->IsCommonReal())
continue;
- if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, Common1) :
- identifyNode(Common1, Common0))) {
+ if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, Common1)
+ : identifyNode(Common1, Common0))) {
PN0->CommonNode = CommonNode;
PN1->CommonNode = CommonNode;
break;
@@ -803,8 +804,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
}
if (!PN0->CommonNode) {
auto PoisonCommon = PoisonValue::get(Common0->getType());
- if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, PoisonCommon) :
- identifyNode(PoisonCommon, Common0))) {
+ if (auto CommonNode =
+ (RealCommon0 ? identifyNode(Common0, PoisonCommon)
+ : identifyNode(PoisonCommon, Common0))) {
PN0->CommonNode = CommonNode;
continue;
}
@@ -846,13 +848,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
return CN;
};
- SmallVector<Product,2> RealMuls{};
- SmallVector<Product,2> ImagMuls{};
+ SmallVector<Product, 2> RealMuls{};
+ SmallVector<Product, 2> ImagMuls{};
Addend RealAddend{nullptr, true};
Addend ImagAddend{nullptr, true};
if (!ProcessInst(Real, RealPositive, RealMuls, RealAddend) ||
!ProcessInst(Imag, ImagPositive, ImagMuls, ImagAddend)) {
- LLVM_DEBUG(dbgs() << " - Failed to match PartialMul in Real/Imag terms.\n");
+ LLVM_DEBUG(
+ dbgs() << " - Failed to match PartialMul in Real/Imag terms.\n");
if (PN && RealPositive == ImagPositive) {
auto CN = identifyNode(Real, Imag);
if (CN) {
@@ -862,8 +865,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
});
return MatchCommons(PN, CN, RealPositive);
}
- LLVM_DEBUG(dbgs() << " - Failed to match Addends "
- << *Real << " / " << *Imag << ".\n");
+ LLVM_DEBUG(dbgs() << " - Failed to match Addends " << *Real << " / "
+ << *Imag << ".\n");
}
return nullptr;
}
@@ -871,97 +874,98 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
if (RealMuls.size() != ImagMuls.size())
return nullptr;
- auto ForeachMatch = [&](Product RealMul, Product ImagMul,
- PartialMulNode *PN, auto &&cb) -> CompositeNode* {
+ auto ForeachMatch = [&](Product RealMul, Product ImagMul, PartialMulNode *PN,
+ auto &&cb) -> CompositeNode * {
PartialMulNode NewPN{};
NewPN.prev = PN;
NewPN.AllowContract = AllowContract;
if (RealMul.IsPositive) {
- NewPN.Rotation = (ImagMul.IsPositive ?
- ComplexDeinterleavingRotation::Rotation_0 :
- ComplexDeinterleavingRotation::Rotation_270);
- }
- else {
- NewPN.Rotation = (ImagMul.IsPositive ?
- ComplexDeinterleavingRotation::Rotation_90 :
- ComplexDeinterleavingRotation::Rotation_180);
- }
- auto IdentifyUncommon = [&] (Value *Real, Value *Imag) {
- return (NewPN.IsCommonReal() ? identifyNode(Real, Imag) :
- identifyNode(Imag, Real));
+ NewPN.Rotation =
+ (ImagMul.IsPositive ? ComplexDeinterleavingRotation::Rotation_0
+ : ComplexDeinterleavingRotation::Rotation_270);
+ } else {
+ NewPN.Rotation =
+ (ImagMul.IsPositive ? ComplexDeinterleavingRotation::Rotation_90
+ : ComplexDeinterleavingRotation::Rotation_180);
+ }
+ auto IdentifyUncommon = [&](Value *Real, Value *Imag) {
+ return (NewPN.IsCommonReal() ? identifyNode(Real, Imag)
+ : identifyNode(Imag, Real));
};
if (RealMul.Multiplier == ImagMul.Multiplier &&
- (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
- ImagMul.Multiplicand))) {
- NewPN.Common = RealMul.Multiplier;
- if (auto CN = cb(&NewPN)) {
- return CN;
- }
+ (NewPN.UncommonNode =
+ IdentifyUncommon(RealMul.Multiplicand, ImagMul.Multiplicand))) {
+ NewPN.Common = RealMul.Multiplier;
+ if (auto CN = cb(&NewPN)) {
+ return CN;
+ }
}
if (ImagMul.Multiplicand != ImagMul.Multiplier &&
RealMul.Multiplier == ImagMul.Multiplicand &&
- (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
- ImagMul.Multiplier))) {
- NewPN.Common = RealMul.Multiplier;
- if (auto CN = cb(&NewPN)) {
- return CN;
- }
+ (NewPN.UncommonNode =
+ IdentifyUncommon(RealMul.Multiplicand, ImagMul.Multiplier))) {
+ NewPN.Common = RealMul.Multiplier;
+ if (auto CN = cb(&NewPN)) {
+ return CN;
+ }
}
if (RealMul.Multiplicand == RealMul.Multiplier)
return nullptr;
if (RealMul.Multiplicand == ImagMul.Multiplier &&
- (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
- ImagMul.Multiplicand))) {
- NewPN.Common = RealMul.Multiplicand;
- if (auto CN = cb(&NewPN)) {
- return CN;
- }
+ (NewPN.UncommonNode =
+ IdentifyUncommon(RealMul.Multiplier, ImagMul.Multiplicand))) {
+ NewPN.Common = RealMul.Multiplicand;
+ if (auto CN = cb(&NewPN)) {
+ return CN;
+ }
}
if (ImagMul.Multiplicand != ImagMul.Multiplier &&
RealMul.Multiplicand == ImagMul.Multiplicand &&
- (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
- ImagMul.Multiplier))) {
- NewPN.Common = RealMul.Multiplicand;
- if (auto CN = cb(&NewPN)) {
- return CN;
- }
+ (NewPN.UncommonNode =
+ IdentifyUncommon(RealMul.Multiplier, ImagMul.Multiplier))) {
+ NewPN.Common = RealMul.Multiplicand;
+ if (auto CN = cb(&NewPN)) {
+ return CN;
+ }
}
return nullptr;
};
if (RealMuls.size() == 1) {
if (!RealAddend.first && !ImagAddend.first) {
- return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
- return MatchCommons(PN, nullptr, RealAddend.second);
- });
+ return ForeachMatch(RealMuls[0], ImagMuls[0], PN,
+ [&](PartialMulNode *PN) {
+ return MatchCommons(PN, nullptr, RealAddend.second);
+ });
}
if (!RealAddend.first || !ImagAddend.first) {
return nullptr;
}
assert(RealAddend.first && ImagAddend.first);
- if (!isa<Instruction>(RealAddend.first) || !isa<Instruction>(ImagAddend.first)) {
+ if (!isa<Instruction>(RealAddend.first) ||
+ !isa<Instruction>(ImagAddend.first)) {
if (RealAddend.second != ImagAddend.second)
return nullptr;
auto CN = identifyNode(RealAddend.first, ImagAddend.first);
if (!CN)
return nullptr;
- return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
- return MatchCommons(PN, CN, RealAddend.second);
- });
+ return ForeachMatch(RealMuls[0], ImagMuls[0], PN,
+ [&](PartialMulNode *PN) {
+ return MatchCommons(PN, CN, RealAddend.second);
+ });
}
return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
return identifyPartialMul(cast<Instruction>(RealAddend.first),
cast<Instruction>(ImagAddend.first),
RealAddend.second, ImagAddend.second, PN);
});
- }
- else {
+ } else {
assert(RealMuls.size() == 2);
assert(!RealAddend.first && !ImagAddend.first);
return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
- return ForeachMatch(RealMuls[1], ImagMuls[1], PN, [&](PartialMulNode *PN) {
- return MatchCommons(PN, nullptr, true);
- });
+ return ForeachMatch(
+ RealMuls[1], ImagMuls[1], PN,
+ [&](PartialMulNode *PN) { return MatchCommons(PN, nullptr, true); });
});
}
}
@@ -997,8 +1001,8 @@ ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
return nullptr;
}
- auto MatchCAdd = [&](Instruction *AR, Instruction *BI,
- Instruction *AI, Instruction *BR) -> CompositeNode* {
+ auto MatchCAdd = [&](Instruction *AR, Instruction *BI, Instruction *AI,
+ Instruction *BR) -> CompositeNode * {
CompositeNode *ResA = identifyNode(AR, AI);
if (!ResA) {
LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
@@ -1355,8 +1359,7 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
Opcode == Instruction::Sub;
};
- if (!IsOperationSupported(Real) ||
- !IsOperationSupported(Imag))
+ if (!IsOperationSupported(Real) || !IsOperationSupported(Imag))
return nullptr;
std::optional<FastMathFlags> Flags;
@@ -1385,7 +1388,8 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
// Collect multiplications and addend instructions from the given instruction
// while traversing it operands. Additionally, verify that all instructions
// have the same fast math flags.
- auto Collect = [&UpdateFlags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
+ auto Collect = [&UpdateFlags](Instruction *Insn,
+ SmallVectorImpl<Product> &Muls,
AddendList &Addends) -> bool {
SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
SmallPtrSet<Value *, 8> Visited;
@@ -1497,8 +1501,8 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
AddendPositive = false;
}
}
- FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode,
- AddendPositive);
+ FinalNode =
+ identifyMultiplications(RealMuls, ImagMuls, FinalNode, AddendPositive);
if (!FinalNode)
return nullptr;
}
@@ -2223,16 +2227,16 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
};
auto CheckValue = [&](Value *V, unsigned ExpectedIdx) {
- if (isa<PoisonValue>(V))
- return true;
- auto EVI = CheckExtract(V, ExpectedIdx, II);
- if (!EVI) {
- II = nullptr;
- return false;
- }
- if (!II)
- II = cast<Instruction>(EVI->getAggregateOperand());
+ if (isa<PoisonValue>(V))
return true;
+ auto EVI = CheckExtract(V, ExpectedIdx, II);
+ if (!EVI) {
+ II = nullptr;
+ return false;
+ }
+ if (!II)
+ II = cast<Instruction>(EVI->getAggregateOperand());
+ return true;
};
for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
@@ -2281,8 +2285,7 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
if (!RealShuffle) {
Op0 = ImagShuffle->getOperand(0);
ShuffleTy = cast<FixedVectorType>(ImagShuffle->getType());
- }
- else {
+ } else {
Op0 = RealShuffle->getOperand(0);
ShuffleTy = cast<FixedVectorType>(RealShuffle->getType());
if (ImagShuffle) {
@@ -2307,12 +2310,13 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
return nullptr;
}
- auto CheckShuffle = [&](ShuffleVectorInst *Shuffle, int Mask0, const char *Name) -> bool {
+ auto CheckShuffle = [&](ShuffleVectorInst *Shuffle, int Mask0,
+ const char *Name) -> bool {
if (!Shuffle) // Poison value
return true;
Value *Op1 = Shuffle->getOperand(1);
if (!isa<UndefValue>(Op1) && !isa<ConstantAggregateZero>(Op1)) {
- LLVM_DEBUG(dbgs() << " - " << Name << "Op1 is not undef or zero.\n");
+ LLVM_DEBUG(dbgs() << " - " << Name << "Op1 is not undef or zero.\n");
return false;
}
ArrayRef<int> Mask = Shuffle->getShuffleMask();
@@ -2321,7 +2325,8 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
return false;
}
if (Mask[0] != Mask0) {
- LLVM_DEBUG(dbgs() << " - " << Name << "Masks do not have the correct initial value.\n");
+ LLVM_DEBUG(dbgs() << " - " << Name
+ << "Masks do not have the correct initial value.\n");
return false;
}
// Ensure that the deinterleaving shuffle only pulls from the first
|
Use an approach similar to how reassoc is handled. In this case, however, we need to maintain the structure of the operations, so instead of collecting a set of multiplications to be added together, we build a stack of multiplications that will be added in stack order. Compared to the old approach, the depth of the stack can be 1 (to match a single unpaired partial multiplication) or arbitrarily deep (to match longer complex computations). As in the reassoc case, we can also walk the stack to find complex pairs of common terms that may be more than one level away from each other.
We are already confirming that everything is consistent with the first operation, so there's no need to check the opcode of every single instruction.
We propagate the negative sign to the top level to maximize the chance of it being merged with other operations (e.g. canceling another neg or merging into an add/sub).
If we can't find a positive addend, we can simply find a negative one and use that as the accumulator. In the worst case we may need to add a negation to the final result, but we get rid of an add/sub between addends and a zero initialization of the accumulator.
For fixed vectors, it's possible to see non-zero masks in splats.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Ref #173274
This fixes or improves on some of the problems mentioned in that issue. It probably needs more tests and some cleanup, but it should be good enough for an initial review.