Skip to content

A few improvements in fcmla pattern recognition #173818

Draft
yuyichao wants to merge 12 commits into llvm:main from
yuyichao:fcmla
Draft

A few improvements in fcmla pattern recognition #173818
yuyichao wants to merge 12 commits into llvm:main from
yuyichao:fcmla

Conversation

@yuyichao
Copy link
Contributor

Ref #173274

This fixes or improves on some of the problems mentioned in that issue. It probably needs more tests and some cleanup, but it should be good enough for an initial review.

* Relax requirement on exact fastmath flag matching

  It should be enough to require that all flags include reassoc.

* Fall back to treating non-reassoc additions as addends to discover more
  deinterleaving opportunities.
@github-actions
Copy link

github-actions bot commented Dec 29, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp -- llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --diff_from_common_commit

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index f86459f80..a9e4bab0a 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -282,7 +282,10 @@ public:
     CompositeNode *CommonNode{nullptr};
     ComplexDeinterleavingRotation Rotation;
     bool AllowContract;
-    bool IsCommonReal() const { return Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180; }
+    bool IsCommonReal() const {
+      return Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
+             Rotation == ComplexDeinterleavingRotation::Rotation_180;
+    }
   };
 
   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
@@ -384,8 +387,8 @@ private:
   }
 
   CompositeNode *negCompositeNode(CompositeNode *Node) {
-    auto NegNode = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
-                                        nullptr, nullptr);
+    auto NegNode = prepareCompositeNode(
+        ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
     NegNode->Opcode = Instruction::FNeg;
     NegNode->addOperand(Node);
     return submitCompositeNode(NegNode);
@@ -403,8 +406,9 @@ private:
   /// 270: r: cr + ai * bi
   ///      i: ci - ai * br
   CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag,
-                                    bool RealPositive=true, bool ImagPositive=true,
-                                    PartialMulNode *PN=nullptr);
+                                    bool RealPositive = true,
+                                    bool ImagPositive = true,
+                                    PartialMulNode *PN = nullptr);
 
   /// Identifies a complex add pattern and its rotation, based on the following
   /// patterns.
@@ -436,9 +440,9 @@ private:
                                    CompositeNode *Accumulator,
                                    bool &AccumPositive);
 
-  /// Extract one addend that have both real and imaginary parts positive/negative.
-  CompositeNode *extractAddend(AddendList &RealAddends,
-                               AddendList &ImagAddends,
+  /// Extract one addend that have both real and imaginary parts
+  /// positive/negative.
+  CompositeNode *extractAddend(AddendList &RealAddends, AddendList &ImagAddends,
                                bool Positive);
 
   /// Determine if sum of multiplications of complex numbers can be formed from
@@ -647,8 +651,7 @@ static const IntrinsicInst *getFMAOrMulAdd(const Instruction *I) {
 }
 
 static inline ComplexDeinterleavingRotation
-flipRotation(ComplexDeinterleavingRotation Rotation, bool Cond=true)
-{
+flipRotation(ComplexDeinterleavingRotation Rotation, bool Cond = true) {
   if (!Cond)
     return Rotation;
   return ComplexDeinterleavingRotation(unsigned(Rotation) ^ 2);
@@ -676,9 +679,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                                bool RealPositive,
                                                bool ImagPositive,
                                                PartialMulNode *PN) {
-  LLVM_DEBUG(dbgs() << "identifyPartialMul "
-                    << (RealPositive ? " + " : " - ") << *Real << " / "
-                    << (ImagPositive ? " + " : " - ") << *Imag << "\n");
+  LLVM_DEBUG(dbgs() << "identifyPartialMul " << (RealPositive ? " + " : " - ")
+                    << *Real << " / " << (ImagPositive ? " + " : " - ") << *Imag
+                    << "\n");
 
   bool AllowContract = true;
 
@@ -702,14 +705,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   };
 
   auto ProcessMulAdd = [&](Product Mul, Addend Add, bool CheckAdd,
-                          SmallVectorImpl<Product> &Muls, Addend &Addend) {
+                           SmallVectorImpl<Product> &Muls, Addend &Addend) {
     Muls.push_back(Mul);
     if (CheckAdd) {
       if (auto AddI = dyn_cast<Instruction>(Add.first)) {
         auto Op = AddI->getOpcode();
         if (Op == Instruction::FMul || Op == Instruction::Mul) {
-          Muls.emplace_back(GetProduct(AddI->getOperand(0), AddI->getOperand(1),
-                                       Add.second));
+          Muls.emplace_back(
+              GetProduct(AddI->getOperand(0), AddI->getOperand(1), Add.second));
           return;
         }
       }
@@ -727,18 +730,18 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       IsPositive = !IsPositive;
     }
     if (auto II = getFMAOrMulAdd(I)) {
-      ProcessMulAdd(GetProduct(II->getArgOperand(0), II->getArgOperand(1),
-                               IsPositive),
-                    GetAddend(II->getArgOperand(2), IsPositive),
-                    II->getFastMathFlags().allowReassoc(), Muls, Addend);
+      ProcessMulAdd(
+          GetProduct(II->getArgOperand(0), II->getArgOperand(1), IsPositive),
+          GetAddend(II->getArgOperand(2), IsPositive),
+          II->getFastMathFlags().allowReassoc(), Muls, Addend);
       return true;
     }
 
     unsigned Opcode = I->getOpcode();
     if (I->hasOneUse() &&
         (Opcode == Instruction::FMul || Opcode == Instruction::Mul)) {
-      Muls.push_back(GetProduct(I->getOperand(0), I->getOperand(1),
-                                IsPositive));
+      Muls.push_back(
+          GetProduct(I->getOperand(0), I->getOperand(1), IsPositive));
       return true;
     }
 
@@ -759,10 +762,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       unsigned Opcode0 = I0->getOpcode();
       if (I0->hasOneUse() &&
           (Opcode0 == Instruction::FMul || Opcode0 == Instruction::Mul)) {
-        ProcessMulAdd(GetProduct(I0->getOperand(0), I0->getOperand(1),
-                                 IsPositive),
-                      GetAddend(Op1, IsPositive ^ IsSub),
-                      true, Muls, Addend);
+        ProcessMulAdd(
+            GetProduct(I0->getOperand(0), I0->getOperand(1), IsPositive),
+            GetAddend(Op1, IsPositive ^ IsSub), true, Muls, Addend);
         return true;
       }
     }
@@ -772,16 +774,15 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
           (Opcode1 == Instruction::FMul || Opcode1 == Instruction::Mul)) {
         ProcessMulAdd(GetProduct(I1->getOperand(0), I1->getOperand(1),
                                  IsPositive ^ IsSub),
-                      GetAddend(Op0, IsPositive),
-                      false, Muls, Addend);
+                      GetAddend(Op0, IsPositive), false, Muls, Addend);
         return true;
       }
     }
     return false;
   };
 
-  auto MatchCommons = [&](PartialMulNode *PN,
-                          CompositeNode *CN, bool CNPositive) -> CompositeNode* {
+  auto MatchCommons = [&](PartialMulNode *PN, CompositeNode *CN,
+                          bool CNPositive) -> CompositeNode * {
     assert(PN);
     for (auto PN0 = PN; PN0; PN0 = PN0->prev) {
       if (PN0->CommonNode)
@@ -794,8 +795,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
         auto Common1 = PN1->Common;
         if (RealCommon0 == PN1->IsCommonReal())
           continue;
-        if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, Common1) :
-                               identifyNode(Common1, Common0))) {
+        if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, Common1)
+                                           : identifyNode(Common1, Common0))) {
           PN0->CommonNode = CommonNode;
           PN1->CommonNode = CommonNode;
           break;
@@ -803,8 +804,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       }
       if (!PN0->CommonNode) {
         auto PoisonCommon = PoisonValue::get(Common0->getType());
-        if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, PoisonCommon) :
-                               identifyNode(PoisonCommon, Common0))) {
+        if (auto CommonNode =
+                (RealCommon0 ? identifyNode(Common0, PoisonCommon)
+                             : identifyNode(PoisonCommon, Common0))) {
           PN0->CommonNode = CommonNode;
           continue;
         }
@@ -846,13 +848,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     return CN;
   };
 
-  SmallVector<Product,2> RealMuls{};
-  SmallVector<Product,2> ImagMuls{};
+  SmallVector<Product, 2> RealMuls{};
+  SmallVector<Product, 2> ImagMuls{};
   Addend RealAddend{nullptr, true};
   Addend ImagAddend{nullptr, true};
   if (!ProcessInst(Real, RealPositive, RealMuls, RealAddend) ||
       !ProcessInst(Imag, ImagPositive, ImagMuls, ImagAddend)) {
-    LLVM_DEBUG(dbgs() << "  - Failed to match PartialMul in Real/Imag terms.\n");
+    LLVM_DEBUG(
+        dbgs() << "  - Failed to match PartialMul in Real/Imag terms.\n");
     if (PN && RealPositive == ImagPositive) {
       auto CN = identifyNode(Real, Imag);
       if (CN) {
@@ -862,8 +865,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
         });
         return MatchCommons(PN, CN, RealPositive);
       }
-      LLVM_DEBUG(dbgs() << "  - Failed to match Addends "
-                 << *Real << " / " << *Imag << ".\n");
+      LLVM_DEBUG(dbgs() << "  - Failed to match Addends " << *Real << " / "
+                        << *Imag << ".\n");
     }
     return nullptr;
   }
@@ -871,97 +874,98 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   if (RealMuls.size() != ImagMuls.size())
     return nullptr;
 
-  auto ForeachMatch = [&](Product RealMul, Product ImagMul,
-                          PartialMulNode *PN, auto &&cb) -> CompositeNode* {
+  auto ForeachMatch = [&](Product RealMul, Product ImagMul, PartialMulNode *PN,
+                          auto &&cb) -> CompositeNode * {
     PartialMulNode NewPN{};
     NewPN.prev = PN;
     NewPN.AllowContract = AllowContract;
     if (RealMul.IsPositive) {
-      NewPN.Rotation = (ImagMul.IsPositive ?
-                        ComplexDeinterleavingRotation::Rotation_0 :
-                        ComplexDeinterleavingRotation::Rotation_270);
-    }
-    else {
-      NewPN.Rotation = (ImagMul.IsPositive ?
-                        ComplexDeinterleavingRotation::Rotation_90 :
-                        ComplexDeinterleavingRotation::Rotation_180);
-    }
-    auto IdentifyUncommon = [&] (Value *Real, Value *Imag) {
-      return (NewPN.IsCommonReal() ? identifyNode(Real, Imag) :
-              identifyNode(Imag, Real));
+      NewPN.Rotation =
+          (ImagMul.IsPositive ? ComplexDeinterleavingRotation::Rotation_0
+                              : ComplexDeinterleavingRotation::Rotation_270);
+    } else {
+      NewPN.Rotation =
+          (ImagMul.IsPositive ? ComplexDeinterleavingRotation::Rotation_90
+                              : ComplexDeinterleavingRotation::Rotation_180);
+    }
+    auto IdentifyUncommon = [&](Value *Real, Value *Imag) {
+      return (NewPN.IsCommonReal() ? identifyNode(Real, Imag)
+                                   : identifyNode(Imag, Real));
     };
     if (RealMul.Multiplier == ImagMul.Multiplier &&
-        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
-                                               ImagMul.Multiplicand))) {
-        NewPN.Common = RealMul.Multiplier;
-        if (auto CN = cb(&NewPN)) {
-          return CN;
-        }
+        (NewPN.UncommonNode =
+             IdentifyUncommon(RealMul.Multiplicand, ImagMul.Multiplicand))) {
+      NewPN.Common = RealMul.Multiplier;
+      if (auto CN = cb(&NewPN)) {
+        return CN;
+      }
     }
     if (ImagMul.Multiplicand != ImagMul.Multiplier &&
         RealMul.Multiplier == ImagMul.Multiplicand &&
-        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
-                                               ImagMul.Multiplier))) {
-        NewPN.Common = RealMul.Multiplier;
-        if (auto CN = cb(&NewPN)) {
-          return CN;
-        }
+        (NewPN.UncommonNode =
+             IdentifyUncommon(RealMul.Multiplicand, ImagMul.Multiplier))) {
+      NewPN.Common = RealMul.Multiplier;
+      if (auto CN = cb(&NewPN)) {
+        return CN;
+      }
     }
     if (RealMul.Multiplicand == RealMul.Multiplier)
       return nullptr;
     if (RealMul.Multiplicand == ImagMul.Multiplier &&
-        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
-                                               ImagMul.Multiplicand))) {
-        NewPN.Common = RealMul.Multiplicand;
-        if (auto CN = cb(&NewPN)) {
-          return CN;
-        }
+        (NewPN.UncommonNode =
+             IdentifyUncommon(RealMul.Multiplier, ImagMul.Multiplicand))) {
+      NewPN.Common = RealMul.Multiplicand;
+      if (auto CN = cb(&NewPN)) {
+        return CN;
+      }
     }
     if (ImagMul.Multiplicand != ImagMul.Multiplier &&
         RealMul.Multiplicand == ImagMul.Multiplicand &&
-        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
-                                               ImagMul.Multiplier))) {
-        NewPN.Common = RealMul.Multiplicand;
-        if (auto CN = cb(&NewPN)) {
-          return CN;
-        }
+        (NewPN.UncommonNode =
+             IdentifyUncommon(RealMul.Multiplier, ImagMul.Multiplier))) {
+      NewPN.Common = RealMul.Multiplicand;
+      if (auto CN = cb(&NewPN)) {
+        return CN;
+      }
     }
     return nullptr;
   };
 
   if (RealMuls.size() == 1) {
     if (!RealAddend.first && !ImagAddend.first) {
-      return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
-        return MatchCommons(PN, nullptr, RealAddend.second);
-      });
+      return ForeachMatch(RealMuls[0], ImagMuls[0], PN,
+                          [&](PartialMulNode *PN) {
+                            return MatchCommons(PN, nullptr, RealAddend.second);
+                          });
     }
     if (!RealAddend.first || !ImagAddend.first) {
       return nullptr;
     }
     assert(RealAddend.first && ImagAddend.first);
-    if (!isa<Instruction>(RealAddend.first) || !isa<Instruction>(ImagAddend.first)) {
+    if (!isa<Instruction>(RealAddend.first) ||
+        !isa<Instruction>(ImagAddend.first)) {
       if (RealAddend.second != ImagAddend.second)
         return nullptr;
       auto CN = identifyNode(RealAddend.first, ImagAddend.first);
       if (!CN)
         return nullptr;
-      return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
-        return MatchCommons(PN, CN, RealAddend.second);
-      });
+      return ForeachMatch(RealMuls[0], ImagMuls[0], PN,
+                          [&](PartialMulNode *PN) {
+                            return MatchCommons(PN, CN, RealAddend.second);
+                          });
     }
     return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
       return identifyPartialMul(cast<Instruction>(RealAddend.first),
                                 cast<Instruction>(ImagAddend.first),
                                 RealAddend.second, ImagAddend.second, PN);
     });
-  }
-  else {
+  } else {
     assert(RealMuls.size() == 2);
     assert(!RealAddend.first && !ImagAddend.first);
     return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
-      return ForeachMatch(RealMuls[1], ImagMuls[1], PN, [&](PartialMulNode *PN) {
-        return MatchCommons(PN, nullptr, true);
-      });
+      return ForeachMatch(
+          RealMuls[1], ImagMuls[1], PN,
+          [&](PartialMulNode *PN) { return MatchCommons(PN, nullptr, true); });
     });
   }
 }
@@ -997,8 +1001,8 @@ ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
     return nullptr;
   }
 
-  auto MatchCAdd = [&](Instruction *AR, Instruction *BI,
-                       Instruction *AI, Instruction *BR) -> CompositeNode* {
+  auto MatchCAdd = [&](Instruction *AR, Instruction *BI, Instruction *AI,
+                       Instruction *BR) -> CompositeNode * {
     CompositeNode *ResA = identifyNode(AR, AI);
     if (!ResA) {
       LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
@@ -1355,8 +1359,7 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
            Opcode == Instruction::Sub;
   };
 
-  if (!IsOperationSupported(Real) ||
-      !IsOperationSupported(Imag))
+  if (!IsOperationSupported(Real) || !IsOperationSupported(Imag))
     return nullptr;
 
   std::optional<FastMathFlags> Flags;
@@ -1385,7 +1388,8 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
   // Collect multiplications and addend instructions from the given instruction
   // while traversing it operands. Additionally, verify that all instructions
   // have the same fast math flags.
-  auto Collect = [&UpdateFlags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
+  auto Collect = [&UpdateFlags](Instruction *Insn,
+                                SmallVectorImpl<Product> &Muls,
                                 AddendList &Addends) -> bool {
     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
     SmallPtrSet<Value *, 8> Visited;
@@ -1497,8 +1501,8 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         AddendPositive = false;
       }
     }
-    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode,
-                                        AddendPositive);
+    FinalNode =
+        identifyMultiplications(RealMuls, ImagMuls, FinalNode, AddendPositive);
     if (!FinalNode)
       return nullptr;
   }
@@ -2223,16 +2227,16 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
   };
 
   auto CheckValue = [&](Value *V, unsigned ExpectedIdx) {
-      if (isa<PoisonValue>(V))
-        return true;
-      auto EVI = CheckExtract(V, ExpectedIdx, II);
-      if (!EVI) {
-        II = nullptr;
-        return false;
-      }
-      if (!II)
-        II = cast<Instruction>(EVI->getAggregateOperand());
+    if (isa<PoisonValue>(V))
       return true;
+    auto EVI = CheckExtract(V, ExpectedIdx, II);
+    if (!EVI) {
+      II = nullptr;
+      return false;
+    }
+    if (!II)
+      II = cast<Instruction>(EVI->getAggregateOperand());
+    return true;
   };
 
   for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
@@ -2281,8 +2285,7 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
   if (!RealShuffle) {
     Op0 = ImagShuffle->getOperand(0);
     ShuffleTy = cast<FixedVectorType>(ImagShuffle->getType());
-  }
-  else {
+  } else {
     Op0 = RealShuffle->getOperand(0);
     ShuffleTy = cast<FixedVectorType>(RealShuffle->getType());
     if (ImagShuffle) {
@@ -2307,12 +2310,13 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
     return nullptr;
   }
 
-  auto CheckShuffle = [&](ShuffleVectorInst *Shuffle, int Mask0, const char *Name) -> bool {
+  auto CheckShuffle = [&](ShuffleVectorInst *Shuffle, int Mask0,
+                          const char *Name) -> bool {
     if (!Shuffle) // Poison value
       return true;
     Value *Op1 = Shuffle->getOperand(1);
     if (!isa<UndefValue>(Op1) && !isa<ConstantAggregateZero>(Op1)) {
-        LLVM_DEBUG(dbgs() << " - " << Name << "Op1 is not undef or zero.\n");
+      LLVM_DEBUG(dbgs() << " - " << Name << "Op1 is not undef or zero.\n");
       return false;
     }
     ArrayRef<int> Mask = Shuffle->getShuffleMask();
@@ -2321,7 +2325,8 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
       return false;
     }
     if (Mask[0] != Mask0) {
-      LLVM_DEBUG(dbgs() << " - " << Name << "Masks do not have the correct initial value.\n");
+      LLVM_DEBUG(dbgs() << " - " << Name
+                        << "Masks do not have the correct initial value.\n");
       return false;
     }
     // Ensure that the deinterleaving shuffle only pulls from the first

Use an approach similar to how reassoc is handled.
However, in this case, we need to maintain the structure of the operations
so instead of collecting a set of multiplications to be added together,
we build a stack of multiplications that will be added in the stack order.

Compared to the old approach, the depth of the stack can be 1
(to match unpaired single partial multiplication) and can also be
arbitrarily deep (to match longer complex computations).
Similar to the reassoc case, we can also walk the stack to find
complex pairs of common terms that may be more than one level
away from each other.
We are already confirming that everything is consistent with the first operation, so there's no need to check the opcode for every single instruction.
We propagate the negative sign to the top level to maximize the chance
of it being merged with other operations
(e.g. canceling another neg or merging into add/sub)
If we couldn't find a positive addend, we could simply find a negative one
and use that as the accumulator.
In the worst case we may need to add a negation to the final result
but we'll get rid of an add/sub between addends and a zero initialization
of the accumulator.
For fixed vectors, it's possible to see non-zero masks in splats.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant